diff options
Diffstat (limited to 'tests')
37 files changed, 2415 insertions, 1874 deletions
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 4003869ed6..236b608d58 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_regex_user_id_prefix_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_member_is_checked(self): @@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_match(self): @@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_no_match(self): @@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_alias_match(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#irc_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) def test_non_exclusive_alias(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( @@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#xmpp_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertFalse((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertFalse( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_regex_multiple_matches(self): @@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_barfoo:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_interested_in_self(self): @@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.type = "m.room.member" self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) - self.store.get_users_in_room.return_value = [ - "@alice:here", - "@irc_fo:here", # AS user - "@bob:here", - ] - self.store.get_aliases_for_room.return_value = [] + # Note that @irc_fo:here is the AS user. + self.store.get_users_in_room.return_value = defer.succeed( + ["@alice:here", "@irc_fo:here", "@bob:here"] + ) + self.store.get_aliases_for_room.return_value = defer.succeed([]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( - (yield self.service.is_interested(event=self.event, store=self.store)) + ( + yield defer.ensureDeferred( + self.service.is_interested(event=self.event, store=self.store) + ) + ) ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 52f89d3f83..68a4caabbf 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -25,6 +25,7 @@ from synapse.appservice.scheduler import ( from synapse.logging.context import make_deferred_yieldable from tests import unittest +from tests.test_utils import make_awaitable from ..utils import MockClock @@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.get_appservice_state = Mock( return_value=defer.succeed(ApplicationServiceState.UP) ) - txn.send = Mock(return_value=defer.succeed(True)) + txn.send = Mock(return_value=make_awaitable(True)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): return_value=defer.succeed(ApplicationServiceState.UP) ) self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) - txn.send = Mock(return_value=defer.succeed(False)) # fails to send + txn.send = Mock(return_value=make_awaitable(False)) # fails to send self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events @@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=True) + txn.send = Mock(return_value=make_awaitable(True)) + txn.complete.return_value = make_awaitable(None) # wait for exp backoff self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) @@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=False) + txn.send = Mock(return_value=make_awaitable(False)) + txn.complete.return_value = make_awaitable(None) self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) @@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEquals(3, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, self.callback.call_count) - txn.send = Mock(return_value=True) # successfully send the txn + txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEquals(1, txn.send.call_count) # new mock reset call count diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f9ce609923..e0ad8e8a77 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): } persp_deferred = defer.Deferred() - @defer.inlineCallbacks - def get_perspectives(**kwargs): + async def get_perspectives(**kwargs): self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): - yield persp_deferred + await persp_deferred return persp_resp self.http_client.post_json.side_effect = get_perspectives @@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): } signedjson.sign.sign_json(response, SERVER_NAME, testkey) - def get_json(destination, path, **kwargs): + async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) self.assertEqual(path, "/_matrix/key/v2/server/key1") return response @@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect a perspectives-server key query """ - def post_json(destination, path, data, **kwargs): + async def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(path, "/_matrix/key/v2/query") @@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # remove the perspectives server's signature response = build_response() del response["signatures"][self.mock_perspective_server.server_name] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") # remove the origin server's signature response = build_response() del response["signatures"][SERVER_NAME] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 640f5f3bce..3a80626224 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase): serialize/deserialize. """ - event, context = create_event( - self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + event, context = self.get_success( + create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) ) self._check_serialize_deserialize(event, context) @@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.test", - sender=self.user_id, - state_key="", + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) ) self._check_serialize_deserialize(event, context) @@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.room.member", - sender=self.user_id, - state_key=self.user_id, - content={"membership": "leave"}, + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) ) self._check_serialize_deserialize(event, context) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 0c9987be54..b8ca118716 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -78,9 +79,40 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # Check whether an admin can join if option "admins_can_join" is undefined, + # this option defaults to false, so the join should fail. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -116,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -141,3 +173,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): f = self.get_failure(d, SynapseError) self.assertEqual(f.value.code, 400) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + +class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): + # Test the behavior of joining rooms which exceed the complexity if option + # limit_remote_rooms.admins_can_join is True. + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["limit_remote_rooms"] = { + "enabled": True, + "complexity": 0.05, + "admins_can_join": True, + } + return config + + def test_join_too_large_no_admin(self): + # A user which is not an admin should not be able to join a remote room + # which is too complex. + + u1 = self.register_user("u1", "pass") + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # An admin should be able to join rooms where a complexity check fails. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request success since the user is an admin + self.get_success(d) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 1a9bd5f37d..5f512ff8bf 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -26,31 +26,34 @@ from synapse.rest import admin from synapse.rest.client.v1 import login from synapse.types import JsonDict, ReadReceipt +from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): + mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) + # Ensure a new Awaitable is created for each call. + mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + ["test", "host2"] + ) return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), + state_handler=mock_state_handler, federation_transport_client=Mock(spec=["send_transaction"]), ) @override_config({"send_federation": True}) def test_send_receipts(self): - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -81,19 +84,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): receipt = ReadReceipt( "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() mock_send_transaction.assert_not_called() @@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - # stub out get_current_hosts_in_room - mock_state_handler = hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - # stub out get_users_who_share_room_with_user so that it claims that # `@user2:host2` is in the room def get_users_who_share_room_with_user(user_id): diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ebabe9a7d6..628f7d8db0 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.handlers.appservice import ApplicationServicesHandler +from tests.test_utils import make_awaitable from tests.utils import MockClock from .. import unittest @@ -117,7 +118,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice_alias(is_interested_in_alias=False), ] - self.mock_as_api.query_alias.return_value = defer.succeed(True) + self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services self.mock_store.get_association_from_room_alias.return_value = defer.succeed( Mock(room_id=room_id, servers=servers) @@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def _mkservice(self, is_interested): service = Mock() - service.is_interested.return_value = defer.succeed(is_interested) + service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" return service diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 00bb776271..bc0c5aefdc 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -16,8 +16,6 @@ from mock import Mock -from twisted.internet import defer - import synapse import synapse.api.errors from synapse.api.constants import EventTypes @@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest +from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): @@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 4f1347cd25..d70e1fc608 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_other_name(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 954e059e76..69945a8f98 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -67,6 +67,14 @@ def get_connection_factory(): return test_server_connection_factory +# Once Async Mocks or lambdas are supported this can go away. +def generate_resolve_service(result): + async def resolve_service(_): + return result + + return resolve_service + + class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv1/foo/bar") @@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has no port, no SRV, and no well-known """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known delegates elsewhere """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self._handle_well_known_connection( client_factory, @@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" @@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"xn--trget-3qa.com", port=8443) # târget.com - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com + ) self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") @@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. """ - - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"target.com", port=8443), - Server(host=b"target.com", port=8444), - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + ) self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index babc201643..fee2985d35 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase): with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - - self.assertNoResult(resolve_d) - - # should have reset to the sentinel context - self.assertIs(current_context(), SENTINEL_CONTEXT) - - result = yield resolve_d + result = yield defer.ensureDeferred(resolve_d) # should have restored our context self.assertIs(current_context(), ctx) @@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertFalse(dns_client_mock.lookupService.called) @@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase): resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): - yield resolver.resolve_service(service_name) + yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks def test_name_error(self): @@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertEquals(len(servers), 0) self.assertEquals(len(cache), 0) @@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in failureResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) # returning a single "." should make the lookup fail with a ConenctError lookup_deferred.callback( @@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in successResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) lookup_deferred.callback( ( diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index fff4f0cbf4..ac598249e4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase): @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: - fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + fetch_d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar") + ) # Nothing happened yet self.assertNoResult(fetch_d) @@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase): """ If the DNS lookup returns an error, it will bubble up. """ - d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + ) self.pump() f = self.failureResultOf(d) @@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value.inner_exception, DNSLookupError) def test_client_connection_refused(self): - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a blacklisted IPv4 address # ------------------------------------------------------ # Make the request - d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000)) # Nothing happened yet self.assertNoResult(d) @@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase): # Try making a POST request to a blacklisted IPv6 address # ------------------------------------------------------- # Make the request - d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + ) # Nothing has happened yet self.assertNoResult(d) @@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a non-blacklisted IPv4 address # ---------------------------------------------------------- # Make the request - d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000)) # Nothing has happened yet self.assertNoResult(d) @@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase): request = MatrixFederationRequest( method="GET", destination="testserv:8008", path="foo/bar" ) - d = self.cl._send_request(request, timeout=10000) + d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000)) self.pump() @@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) + defer.ensureDeferred( + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} + ) + ) self.pump() @@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase): def test_closes_connection(self): """Check that the client closes unused HTTP connections""" - d = self.cl.get_json("testserv:8008", "foo/bar") + d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) self.pump() diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 1a88c7fb80..0b5204654c 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_handler = self.hs.get_state_handler() context = self.get_success(state_handler.compute_event_context(event)) - self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + self.get_success( + self.master_store.add_push_actions_to_staging( + event.event_id, {user_id: actions for user_id, actions in push_actions} + ) ) return event, context diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 097e1653b4..c9998e88e6 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase): OTHER_USER = "@other_user:localhost" # have the user join - inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -157,14 +159,16 @@ class EventsStreamTestCase(BaseStreamTestCase): # roll back all the state by de-modding the user prev_events = fork_point pls["users"][OTHER_USER] = 0 - pl_event = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + pl_event = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) # one more bit of state that doesn't get rolled back @@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # have the users join for u in user_ids: - inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -306,14 +312,16 @@ class EventsStreamTestCase(BaseStreamTestCase): pl_events = [] for u in user_ids: pls["users"][u] = 0 - e = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + e = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) prev_events = [e.event_id] pl_events.append(e) @@ -434,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase): body = "event %i" % (self.event_count,) self.event_count += 1 - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_event", - content={"body": body}, - **kwargs + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) ) def _inject_state_event( @@ -459,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase): if body is None: body = "state event %s" % (state_key,) - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_state_event", - state_key=state_key, - content={"body": body}, + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 8d4dbf232e..83f9aa291c 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -16,8 +16,6 @@ import logging from mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource @@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import make_awaitable logger = logging.getLogger(__name__) @@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index b1a4decced..0f1144fe1e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + async def get_file(destination, path, output_stream, args=None, max_size=None): """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): d = Deferred() d.addCallback(write_to) self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + return await make_deferred_yieldable(d) client = Mock() client.get_file = get_file diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 946f06d151..cec1cf928f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1,1447 +1,1500 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Dirk Klimpel -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import urllib.parse -from typing import List, Optional - -from mock import Mock - -import synapse.rest.admin -from synapse.api.errors import Codes -from synapse.rest.client.v1 import directory, events, login, room - -from tests import unittest - -"""Tests admin REST events for /rooms paths.""" - - -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ - - url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - -class DeleteRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_tok = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - self.room_id = self.helper.create_room_as( - self.other_user, tok=self.other_user_tok - ) - self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error 403 is returned. - """ - - request, channel = self.make_request( - "POST", self.url, json.dumps({}), access_token=self.other_user_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_room_does_not_exist(self): - """ - Check that unknown rooms/server return error 404. - """ - url = "/_synapse/admin/v1/rooms/!unknown:test/delete" - - request, channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_room_is_not_valid(self): - """ - Check that invalid room names, return an error 400. - """ - url = "/_synapse/admin/v1/rooms/invalidroom/delete" - - request, channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "invalidroom is not a legal room ID", channel.json_body["error"], - ) - - def test_new_room_user_does_not_exist(self): - """ - Tests that the user ID must be from local server but it does not have to exist. - """ - body = json.dumps({"new_room_user_id": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertIn("new_room_id", channel.json_body) - self.assertIn("kicked_users", channel.json_body) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - def test_new_room_user_is_not_local(self): - """ - Check that only local users can create new room to move members. - """ - body = json.dumps({"new_room_user_id": "@not:exist.bla"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "User must be our own: @not:exist.bla", channel.json_body["error"], - ) - - def test_block_is_not_bool(self): - """ - If parameter `block` is not boolean, return an error - """ - body = json.dumps({"block": "NotBool"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - - def test_purge_room_and_block(self): - """Test to purge a room and block it. - Members will not be moved to a new room and will not receive a message. - """ - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Test that room is not blocked - self._is_blocked(self.room_id, expect=False) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - body = json.dumps({"block": True}) - - request, channel = self.make_request( - "POST", - self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(None, channel.json_body["new_room_id"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - self._is_purged(self.room_id) - self._is_blocked(self.room_id, expect=True) - self._has_no_members(self.room_id) - - def test_purge_room_and_not_block(self): - """Test to purge a room and do not block it. - Members will not be moved to a new room and will not receive a message. - """ - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Test that room is not blocked - self._is_blocked(self.room_id, expect=False) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - body = json.dumps({"block": False}) - - request, channel = self.make_request( - "POST", - self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(None, channel.json_body["new_room_id"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - self._is_purged(self.room_id) - self._is_blocked(self.room_id, expect=False) - self._has_no_members(self.room_id) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - Members will be moved to a new room and will receive a message. - """ - self.event_creation_handler._block_events_without_consent_error = None - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 - ) - - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("new_room_id", channel.json_body) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - # Test that member has moved to new room - self._is_member( - room_id=channel.json_body["new_room_id"], user_id=self.other_user - ) - - self._is_purged(self.room_id) - self._has_no_members(self.room_id) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - Members will be moved to a new room and will receive a message. - """ - self.event_creation_handler._block_events_without_consent_error = None - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that room is not purged - with self.assertRaises(AssertionError): - self._is_purged(self.room_id) - - # Assert one user in room - self._is_member(room_id=self.room_id, user_id=self.other_user) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) - self.assertIn("new_room_id", channel.json_body) - self.assertIn("failed_to_kick_users", channel.json_body) - self.assertIn("local_aliases", channel.json_body) - - # Test that member has moved to new room - self._is_member( - room_id=channel.json_body["new_room_id"], user_id=self.other_user - ) - - self._is_purged(self.room_id) - self._has_no_members(self.room_id) - - # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=403) - - def _is_blocked(self, room_id, expect=True): - """Assert that the room is blocked or not - """ - d = self.store.is_room_blocked(room_id) - if expect: - self.assertTrue(self.get_success(d)) - else: - self.assertIsNone(self.get_success(d)) - - def _has_no_members(self, room_id): - """Assert there is now no longer anyone in the room - """ - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def _is_member(self, room_id, user_id): - """Test that user is member of the room - """ - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertIn(user_id, users_in_room) - - def _is_purged(self, room_id): - """Test that the following tables have been purged of all rows related to the room. - """ - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. - "state_groups_state", - ): - count = self.get_success( - self.store.db.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ - - url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. - "state_groups_state", - ): - count = self.get_success( - self.store.db.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - - -class RoomTestCase(unittest.HomeserverTestCase): - """Test /room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - directory.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - # Create user - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_list_rooms(self): - """Test that we can list rooms""" - # Create 3 test rooms - total_rooms = 3 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - - # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) - - # Check that response json body contains a "rooms" key - self.assertTrue( - "rooms" in channel.json_body, - msg="Response body does not " "contain a 'rooms' key", - ) - - # Check that 3 rooms were returned - self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) - - # Check their room_ids match - returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] - self.assertEqual(room_ids, returned_room_ids) - - # Check that all fields are available - for r in channel.json_body["rooms"]: - self.assertIn("name", r) - self.assertIn("canonical_alias", r) - self.assertIn("joined_members", r) - self.assertIn("joined_local_members", r) - self.assertIn("version", r) - self.assertIn("creator", r) - self.assertIn("encryption", r) - self.assertIn("federatable", r) - self.assertIn("public", r) - self.assertIn("join_rules", r) - self.assertIn("guest_access", r) - self.assertIn("history_visibility", r) - self.assertIn("state_events", r) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # Should be 0 as we aren't paginating - self.assertEqual(channel.json_body["offset"], 0) - - # Check that the prev_batch parameter is not present - self.assertNotIn("prev_batch", channel.json_body) - - # We shouldn't receive a next token here as there's no further rooms to show - self.assertNotIn("next_batch", channel.json_body) - - def test_list_rooms_pagination(self): - """Test that we can get a full list of rooms through pagination""" - # Create 5 test rooms - total_rooms = 5 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Set the name of the rooms so we get a consistent returned ordering - for idx, room_id in enumerate(room_ids): - self.helper.send_state( - room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - returned_room_ids = [] - start = 0 - limit = 2 - - run_count = 0 - should_repeat = True - while should_repeat: - run_count += 1 - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( - start, - limit, - "name", - ) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - self.assertTrue("rooms" in channel.json_body) - for r in channel.json_body["rooms"]: - returned_room_ids.append(r["room_id"]) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # We're only getting 2 rooms each page, so should be 2 * last run_count - self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) - - if run_count > 1: - # Check the value of prev_batch is correct - self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) - - if "next_batch" not in channel.json_body: - # We have reached the end of the list - should_repeat = False - else: - # Make another query with an updated start value - start = channel.json_body["next_batch"] - - # We should've queried the endpoint 3 times - self.assertEqual( - run_count, - 3, - msg="Should've queried 3 times for 5 rooms with limit 2 per query", - ) - - # Check that we received all of the room ids - self.assertEqual(room_ids, returned_room_ids) - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - def test_correct_room_attributes(self): - """Test the correct attributes for a room are returned""" - # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - test_alias = "#test:test" - test_room_name = "something" - - # Have another user join the room - user_2 = self.register_user("user4", "pass") - user_tok_2 = self.login("user4", "pass") - self.helper.join(room_id, user_2, tok=user_tok_2) - - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=self.admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=self.admin_user_tok, - ) - - # Set a name for the room - self.helper.send_state( - room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that only one room was returned - self.assertEqual(len(rooms), 1) - - # And that the value of the total_rooms key was correct - self.assertEqual(channel.json_body["total_rooms"], 1) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that all provided attributes are set - r = rooms[0] - self.assertEqual(room_id, r["room_id"]) - self.assertEqual(test_room_name, r["name"]) - self.assertEqual(test_alias, r["canonical_alias"]) - - def test_room_list_sort_order(self): - """Test room list sort ordering. alphabetical name versus number of members, - reversing the order, etc. - """ - - def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % ( - urllib.parse.quote(test_alias), - ) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=admin_user_tok, - ) - - def _order_test( - order_type: str, expected_room_list: List[str], reverse: bool = False, - ): - """Request the list of rooms in a certain order. Assert that order is what - we expect - - Args: - order_type: The type of ordering to give the server - expected_room_list: The list of room_ids in the order we expect to get - back from the server - """ - # Request the list of rooms in the given order - url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) - if reverse: - url += "&dir=b" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check for the correct total_rooms value - self.assertEqual(channel.json_body["total_rooms"], 3) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that rooms were returned in alphabetical order - returned_order = [r["room_id"] for r in rooms] - self.assertListEqual(expected_room_list, returned_order) # order is checked - - # Create 3 test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C - self.helper.send_state( - room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, - ) - - # Set room canonical room aliases - _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) - _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) - _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) - - # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 - user_1 = self.register_user("bob1", "pass") - user_1_tok = self.login("bob1", "pass") - self.helper.join(room_id_2, user_1, tok=user_1_tok) - - user_2 = self.register_user("bob2", "pass") - user_2_tok = self.login("bob2", "pass") - self.helper.join(room_id_3, user_2, tok=user_2_tok) - - user_3 = self.register_user("bob3", "pass") - user_3_tok = self.login("bob3", "pass") - self.helper.join(room_id_3, user_3, tok=user_3_tok) - - # Test different sort orders, with forward and reverse directions - _order_test("name", [room_id_1, room_id_2, room_id_3]) - _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) - _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) - _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) - _order_test( - "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True - ) - - _order_test("version", [room_id_1, room_id_2, room_id_3]) - _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("creator", [room_id_1, room_id_2, room_id_3]) - _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("encryption", [room_id_1, room_id_2, room_id_3]) - _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("federatable", [room_id_1, room_id_2, room_id_3]) - _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("public", [room_id_1, room_id_2, room_id_3]) - # Different sort order of SQlite and PostreSQL - # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) - _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) - _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) - _order_test( - "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True - ) - - _order_test("state_events", [room_id_3, room_id_2, room_id_1]) - _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - - def test_search_term(self): - """Test that searching for a room works correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - def _search_test( - expected_room_id: Optional[str], - search_term: str, - expected_http_code: int = 200, - ): - """Search for a room and check that the returned room's id is a match - - Args: - expected_room_id: The room_id expected to be returned by the API. Set - to None to expect zero results for the search - search_term: The term to search for room names with - expected_http_code: The expected http code for the request - """ - url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - - if expected_http_code != 200: - return - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that the expected number of rooms were returned - expected_room_count = 1 if expected_room_id else 0 - self.assertEqual(len(rooms), expected_room_count) - self.assertEqual(channel.json_body["total_rooms"], expected_room_count) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - if expected_room_id: - # Check that the first returned room id is correct - r = rooms[0] - self.assertEqual(expected_room_id, r["room_id"]) - - # Perform search tests - _search_test(room_id_1, "something") - _search_test(room_id_1, "thing") - - _search_test(room_id_2, "else") - _search_test(room_id_2, "se") - - _search_test(None, "foo") - _search_test(None, "bar") - _search_test(None, "", expected_http_code=400) - - def test_single_room(self): - """Test that a single room can be requested correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - self.assertIn("room_id", channel.json_body) - self.assertIn("name", channel.json_body) - self.assertIn("canonical_alias", channel.json_body) - self.assertIn("joined_members", channel.json_body) - self.assertIn("joined_local_members", channel.json_body) - self.assertIn("version", channel.json_body) - self.assertIn("creator", channel.json_body) - self.assertIn("encryption", channel.json_body) - self.assertIn("federatable", channel.json_body) - self.assertIn("public", channel.json_body) - self.assertIn("join_rules", channel.json_body) - self.assertIn("guest_access", channel.json_body) - self.assertIn("history_visibility", channel.json_body) - self.assertIn("state_events", channel.json_body) - - self.assertEqual(room_id_1, channel.json_body["room_id"]) - - def test_room_members(self): - """Test that room members can be requested correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # Have another user join the room - user_1 = self.register_user("foo", "pass") - user_tok_1 = self.login("foo", "pass") - self.helper.join(room_id_1, user_1, tok=user_tok_1) - - # Have another user join the room - user_2 = self.register_user("bar", "pass") - user_tok_2 = self.login("bar", "pass") - self.helper.join(room_id_1, user_2, tok=user_tok_2) - self.helper.join(room_id_2, user_2, tok=user_tok_2) - - # Have another user join the room - user_3 = self.register_user("foobar", "pass") - user_tok_3 = self.login("foobar", "pass") - self.helper.join(room_id_2, user_3, tok=user_tok_3) - - url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - self.assertCountEqual( - ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] - ) - self.assertEqual(channel.json_body["total"], 3) - - url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - self.assertCountEqual( - ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] - ) - self.assertEqual(channel.json_body["total"], 3) - - -class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.creator = self.register_user("creator", "test") - self.creator_tok = self.login("creator", "test") - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - - self.public_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=True - ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) - - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error 403 is returned. - """ - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.second_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_invalid_parameter(self): - """ - If a parameter is missing, return an error - """ - body = json.dumps({"unknown_parameter": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - - def test_local_user_does_not_exist(self): - """ - Tests that a lookup for a user that does not exist returns a 404 - """ - body = json.dumps({"user_id": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_remote_user(self): - """ - Check that only local user can join rooms. - """ - body = json.dumps({"user_id": "@not:exist.bla"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "This endpoint can only be used with local users", - channel.json_body["error"], - ) - - def test_room_does_not_exist(self): - """ - Check that unknown rooms/server return error 404. - """ - body = json.dumps({"user_id": self.second_user_id}) - url = "/_synapse/admin/v1/join/!unknown:test" - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No known servers", channel.json_body["error"]) - - def test_room_is_not_valid(self): - """ - Check that invalid room names, return an error 400. - """ - body = json.dumps({"user_id": self.second_user_id}) - url = "/_synapse/admin/v1/join/invalidroom" - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "invalidroom was not legal room ID or room alias", - channel.json_body["error"], - ) - - def test_join_public_room(self): - """ - Test joining a local user to a public room with "JoinRules.PUBLIC" - """ - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.public_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - - def test_join_private_room_if_not_member(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE" - when server admin is not member of this room. - """ - private_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False - ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_join_private_room_if_member(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE", - when server admin is member of this room. - """ - private_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False - ) - self.helper.invite( - room=private_room_id, - src=self.creator, - targ=self.admin_user, - tok=self.creator_tok, - ) - self.helper.join( - room=private_room_id, user=self.admin_user, tok=self.admin_user_tok - ) - - # Validate if server admin is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - - # Join user to room. - - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - - def test_join_private_room_if_owner(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE", - when server admin is owner of this room. - """ - private_room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=False - ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import urllib.parse +from typing import List, Optional + +from mock import Mock + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import directory, events, login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class ShutdownRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + """ + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + room_id, body="foo", tok=self.other_user_token, expect_code=403 + ) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert there is now no longer anyone in the room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class DeleteRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "POST", self.url, json.dumps({}), access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + url = "/_synapse/admin/v1/rooms/!unknown:test/delete" + + request, channel = self.make_request( + "POST", url, json.dumps({}), access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + url = "/_synapse/admin/v1/rooms/invalidroom/delete" + + request, channel = self.make_request( + "POST", url, json.dumps({}), access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom is not a legal room ID", channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + body = json.dumps({"new_room_user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("kicked_users", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. + """ + body = json.dumps({"new_room_user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "User must be our own: @not:exist.bla", channel.json_body["error"], + ) + + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + body = json.dumps({"block": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + body = json.dumps({"purge": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": True, "purge": True}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False, "purge": True}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False, "purge": False}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id, expect=True): + """Assert that the room is blocked or not + """ + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id): + """Assert there is now no longer anyone in the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id, user_id): + """Test that user is member of the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id): + """Test that the following tables have been purged of all rows related to the room. + """ + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_list_rooms(self): + """Test that we can list rooms""" + # Create 3 test rooms + total_rooms = 3 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + + # Check request completed successfully + self.assertEqual(200, int(channel.code), msg=channel.json_body) + + # Check that response json body contains a "rooms" key + self.assertTrue( + "rooms" in channel.json_body, + msg="Response body does not " "contain a 'rooms' key", + ) + + # Check that 3 rooms were returned + self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) + + # Check their room_ids match + returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] + self.assertEqual(room_ids, returned_room_ids) + + # Check that all fields are available + for r in channel.json_body["rooms"]: + self.assertIn("name", r) + self.assertIn("canonical_alias", r) + self.assertIn("joined_members", r) + self.assertIn("joined_local_members", r) + self.assertIn("version", r) + self.assertIn("creator", r) + self.assertIn("encryption", r) + self.assertIn("federatable", r) + self.assertIn("public", r) + self.assertIn("join_rules", r) + self.assertIn("guest_access", r) + self.assertIn("history_visibility", r) + self.assertIn("state_events", r) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # Should be 0 as we aren't paginating + self.assertEqual(channel.json_body["offset"], 0) + + # Check that the prev_batch parameter is not present + self.assertNotIn("prev_batch", channel.json_body) + + # We shouldn't receive a next token here as there's no further rooms to show + self.assertNotIn("next_batch", channel.json_body) + + def test_list_rooms_pagination(self): + """Test that we can get a full list of rooms through pagination""" + # Create 5 test rooms + total_rooms = 5 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Set the name of the rooms so we get a consistent returned ordering + for idx, room_id in enumerate(room_ids): + self.helper.send_state( + room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + returned_room_ids = [] + start = 0 + limit = 2 + + run_count = 0 + should_repeat = True + while should_repeat: + run_count += 1 + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( + start, + limit, + "name", + ) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + self.assertTrue("rooms" in channel.json_body) + for r in channel.json_body["rooms"]: + returned_room_ids.append(r["room_id"]) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # We're only getting 2 rooms each page, so should be 2 * last run_count + self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) + + if run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical name versus number of members, + reversing the order, etc. + """ + + def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % ( + urllib.parse.quote(test_alias), + ) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=admin_user_tok, + ) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room canonical room aliases + _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) + + # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 + user_1 = self.register_user("bob1", "pass") + user_1_tok = self.login("bob1", "pass") + self.helper.join(room_id_2, user_1, tok=user_1_tok) + + user_2 = self.register_user("bob2", "pass") + user_2_tok = self.login("bob2", "pass") + self.helper.join(room_id_3, user_2, tok=user_2_tok) + + user_3 = self.register_user("bob3", "pass") + user_3_tok = self.login("bob3", "pass") + self.helper.join(room_id_3, user_3, tok=user_3_tok) + + # Test different sort orders, with forward and reverse directions + _order_test("name", [room_id_1, room_id_2, room_id_3]) + _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) + _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) + _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) + _order_test( + "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("version", [room_id_1, room_id_2, room_id_3]) + _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("creator", [room_id_1, room_id_2, room_id_3]) + _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("encryption", [room_id_1, room_id_2, room_id_3]) + _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("federatable", [room_id_1, room_id_2, room_id_3]) + _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("public", [room_id_1, room_id_2, room_id_3]) + # Different sort order of SQlite and PostreSQL + # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) + _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) + _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) + _order_test( + "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("state_events", [room_id_3, room_id_2, room_id_1]) + _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) + + def test_search_term(self): + """Test that searching for a room works correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + def _search_test( + expected_room_id: Optional[str], + search_term: str, + expected_http_code: int = 200, + ): + """Search for a room and check that the returned room's id is a match + + Args: + expected_room_id: The room_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) + + def test_single_room(self): + """Test that a single room can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertIn("room_id", channel.json_body) + self.assertIn("name", channel.json_body) + self.assertIn("canonical_alias", channel.json_body) + self.assertIn("joined_members", channel.json_body) + self.assertIn("joined_local_members", channel.json_body) + self.assertIn("version", channel.json_body) + self.assertIn("creator", channel.json_body) + self.assertIn("encryption", channel.json_body) + self.assertIn("federatable", channel.json_body) + self.assertIn("public", channel.json_body) + self.assertIn("join_rules", channel.json_body) + self.assertIn("guest_access", channel.json_body) + self.assertIn("history_visibility", channel.json_body) + self.assertIn("state_events", channel.json_body) + + self.assertEqual(room_id_1, channel.json_body["room_id"]) + + def test_room_members(self): + """Test that room members can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + # Have another user join the room + user_2 = self.register_user("bar", "pass") + user_tok_2 = self.login("bar", "pass") + self.helper.join(room_id_1, user_2, tok=user_tok_2) + self.helper.join(room_id_2, user_2, tok=user_tok_2) + + # Have another user join the room + user_3 = self.register_user("foobar", "pass") + user_tok_3 = self.login("foobar", "pass") + self.helper.join(room_id_2, user_3, tok=user_tok_3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. + + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 22d734e763..7f8252330a 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -143,6 +143,26 @@ class RestHelper(object): return channel.json_body + def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body + def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1bd..a31e44c97e 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + request, channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: True): + """Syncs and compares the unread count with the expected value.""" + + request, channel = self.make_request( + "GET", self.url % self.next_batch, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 99eb477149..6850c666be 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json(destination, path, ignore_backoff=False, **kwargs): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - def post_json(destination, path, data): + async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, "/_matrix/key/v2/query", diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 66fa5978b2..f4f3e56777 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -26,6 +26,7 @@ import attr from parameterized import parameterized_class from PIL import Image as Image +from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable @@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase): # This uses a real blocking threadpool so we have to wait for it to be # actually done :/ - x = self.media_storage.ensure_media_is_in_local_cache(file_info) + x = defer.ensureDeferred( + self.media_storage.ensure_media_is_in_local_cache(file_info) + ) # Hotloop until the threadpool does its job... self.wait_on_thread(x) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 2826211f32..74765a582b 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -12,8 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import os +import re + +from mock import patch import attr @@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() def test_cache_returns_correct_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://matrix.org", shorthand=False @@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def test_non_ascii_preview_httpequiv(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"<html><head>" @@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_non_ascii_preview_content_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"<html><head>" @@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_overlong_title(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"<html><head>" @@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ IP addresses can be previewed directly. """ - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False @@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Hardcode the URL resolving to the IP we want. self.lookups["example.com"] = [ (IPv4Address, "1.1.1.2"), - (IPv4Address, "8.8.8.8"), + (IPv4Address, "10.1.2.3"), ] request, channel = self.make_request( @@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Accept-Language header is sent to the remote server """ - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] # Build and make a request to the server request, channel = self.make_request( @@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase): ), server.data, ) + + def test_oembed_photo(self): + """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "photo", + "url": "http://cdn.twitter.com/matrixdotorg", + } + oembed_content = json.dumps(result).encode("utf-8") + + end_content = ( + b"<html><head>" + b"<title>Some Title</title>" + b'<meta property="og:description" content="hi" />' + b"</head></html>" + ) + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(oembed_content),) + + oembed_content + ) + + self.pump() + + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + ) + + def test_oembed_rich(self): + """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "rich", + "html": "<div>Content Preview</div>", + } + end_content = json.dumps(result).encode("utf-8") + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + {"og:title": None, "og:description": "Content Preview"}, + ) diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 38f9b423ef..f2955a9c69 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +from typing import List import attr @@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(event_map), ) - state_before = self.successResultOf(state_d) + state_before = self.successResultOf(defer.ensureDeferred(state_d)) state_after = dict(state_before) if fake_event.state_key is not None: @@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(self.event_map), ) - state = self.successResultOf(state_d) + state = self.successResultOf(defer.ensureDeferred(state_d)) self.assert_dict(self.expected_combined_state, state) @@ -608,9 +609,11 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + return defer.succeed( + {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + ) - def _get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids: List[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -622,10 +625,10 @@ class TestStateResolutionStore(object): presence of rejected events Args: - event_ids (list): The event IDs of the events to fetch the auth + event_ids: The event IDs of the events to fetch the auth chain for. Must be state events. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + List of event IDs of the auth chain. """ # Simple DFS for auth chain @@ -648,4 +651,4 @@ class TestStateResolutionStore(object): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) - return set(chains[0]).union(*chains[1:]) - common + return defer.succeed(set(chains[0]).union(*chains[1:]) - common) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b45bc9c115..2b1580feeb 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks @@ -72,8 +76,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield self.store.add_push_actions_to_staging( - event.event_id, {user_id: action} + yield defer.ensureDeferred( + self.store.add_push_actions_to_staging( + event.event_id, {user_id: action} + ) ) yield self.store.db.runInteraction( "", diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index b9fafaa1a6..a6012c973d 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_events.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred( + storage.purge_events.purge_history(self.room_id, event, True) + ) self.pump() self.assertEqual(self.successResultOf(purge), None) @@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): ) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) self.pump() f = self.failureResultOf(purge) self.assertIn("greater than forward", f.value.args[0]) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index db3667dc43..0f0e1cd09b 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): - built_event = yield self._base_builder.build(prev_event_ids) + built_event = yield defer.ensureDeferred( + self._base_builder.build(prev_event_ids) + ) built_event._event_id = self._event_id built_event._dict["event_id"] = self._event_id diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index b1dceb2918..d07b985a8e 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id=self.u_creator.to_string(), - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id=self.u_creator.to_string(), + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -88,17 +90,21 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.storage.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + yield defer.ensureDeferred( + self.storage.persistence.persist_event( + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + ) ) @defer.inlineCallbacks @@ -109,7 +115,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( @@ -125,7 +133,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5dd46005e6..f282921538 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def test_get_joined_users_from_context(self): room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN + bob_event = self.get_success( + event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) ) # first, create a regular event - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) ) users = self.get_success( @@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Regression test for #7376: create a state event whose key matches bob's # user_id, but which is *not* a membership event, and persist that; then check # that `get_joined_users_from_context` returns the correct users for the next event. - non_member_event = event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, + non_member_event = self.get_success( + event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) ) - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) ) users = self.get_success( self.store.get_joined_users_from_context(event, context) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b88308ff4..8bd12fa847 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -64,11 +66,13 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups_ids( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.storage.state.get_state_for_event(e5.event_id) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event(e5.event_id) + ) self.assertIsNotNone(e4) @@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) self.assertStateMapEqual( @@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, - include_others=True, - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ), + ) ) self.assertStateMapEqual( @@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True + ), + ) ) self.assertStateMapEqual( @@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.storage.state.get_state_groups_ids( - room_id, [e5.event_id] + group_ids = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] diff --git a/tests/test_federation.py b/tests/test_federation.py index 87a16d7d7a..c2f12c2741 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): prev_events that said event references. """ - def post_json(destination, path, data, headers=None, timeout=0): + async def post_json(destination, path, data, headers=None, timeout=0): # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} diff --git a/tests/test_server.py b/tests/test_server.py index 030f58cbdc..073b2362cc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -12,26 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import re -from io import StringIO from twisted.internet.defer import Deferred -from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.config.server import parse_listener_def from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource -from synapse.http.site import SynapseSite, logger +from synapse.http.site import SynapseSite from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock from tests import unittest from tests.server import ( - FakeTransport, ThreadedMemoryReactorClock, make_request, render, @@ -199,10 +193,10 @@ class OptionsResourceTests(unittest.TestCase): return channel def test_unknown_options_request(self): - """An OPTIONS requests to an unknown URL still returns 200 OK.""" + """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.result["body"], b"{}") + self.assertEqual(channel.result["code"], b"204") + self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added self.assertTrue( @@ -219,10 +213,10 @@ class OptionsResourceTests(unittest.TestCase): ) def test_known_options_request(self): - """An OPTIONS requests to an known URL still returns 200 OK.""" + """An OPTIONS requests to an known URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/res/") - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.result["body"], b"{}") + self.assertEqual(channel.result["code"], b"204") + self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added self.assertTrue( @@ -318,54 +312,3 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(location_headers, [b"/no/over/there"]) cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] self.assertEqual(cookies_headers, [b"session=yespls"]) - - -class SiteTestCase(unittest.HomeserverTestCase): - def test_lose_connection(self): - """ - We log the URI correctly redacted when we lose the connection. - """ - - class HangingResource(Resource): - """ - A Resource that strategically hangs, as if it were processing an - answer. - """ - - def render(self, request): - return NOT_DONE_YET - - # Set up a logging handler that we can inspect afterwards - output = StringIO() - handler = logging.StreamHandler(output) - logger.addHandler(handler) - old_level = logger.level - logger.setLevel(10) - self.addCleanup(logger.setLevel, old_level) - self.addCleanup(logger.removeHandler, handler) - - # Make a resource and a Site, the resource will hang and allow us to - # time out the request while it's 'processing' - base_resource = Resource() - base_resource.putChild(b"", HangingResource()) - site = SynapseSite( - "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0" - ) - - server = site.buildProtocol(None) - client = AccumulatingProtocol() - client.makeConnection(FakeTransport(server, self.reactor)) - server.makeConnection(FakeTransport(client, self.reactor)) - - # Send a request with an access token that will get redacted - server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n") - self.pump() - - # Lose the connection - e = Failure(Exception("Failed123")) - server.connectionLost(e) - handler.flush() - - # Our access token is redacted and the failure reason is logged. - self.assertIn("/?access_token=<redacted>", output.getvalue()) - self.assertIn("Failed123", output.getvalue()) diff --git a/tests/test_state.py b/tests/test_state.py index 66f22f6813..b5c3667d2a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -97,17 +97,19 @@ class StateGroupStore(object): self._group_to_state[state_group] = dict(current_state_ids) - return state_group + return defer.succeed(state_group) def get_events(self, event_ids, **kwargs): - return { - e_id: self._event_id_to_event[e_id] - for e_id in event_ids - if e_id in self._event_id_to_event - } + return defer.succeed( + { + e_id: self._event_id_to_event[e_id] + for e_id in event_ids + if e_id in self._event_id_to_event + } + ) def get_state_group_delta(self, name): - return None, None + return defer.succeed((None, None)) def register_events(self, events): for e in events: @@ -120,7 +122,7 @@ class StateGroupStore(object): self._event_to_state_group[event_id] = state_group def get_room_version_id(self, room_id): - return RoomVersions.V1.identifier + return defer.succeed(RoomVersions.V1.identifier) class DictObj(dict): @@ -202,14 +204,16 @@ class StateTestCase(unittest.TestCase): context_store = {} # type: dict[str, EventContext] for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -253,7 +259,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -310,7 +318,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids()) self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -383,7 +393,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) @@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual( {e.event_id for e in old_state}, set(current_state_ids.values()) @@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -503,9 +517,9 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) @@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) + @defer.inlineCallbacks def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = self.store.store_state_group( + sg1 = yield self.store.store_state_group( prev_event_id_1, event.room_id, None, @@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = self.store.store_state_group( + sg2 = yield self.store.store_state_group( prev_event_id_2, event.room_id, None, @@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_2, sg2) - return self.state.compute_event_context(event) + result = yield defer.ensureDeferred(self.state.compute_event_context(event)) + return result diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 7b345b03bb..508aeba078 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,7 +17,7 @@ """ Utilities for running the unit tests """ -from typing import Awaitable, TypeVar +from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: # if next didn't raise, the awaitable hasn't completed. raise Exception("awaitable has not yet completed") + + +async def make_awaitable(result: Any): + """Create an awaitable that just returns a result.""" + return result diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 43297b530c..8522c6fc09 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -22,14 +22,12 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Collection -from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. """ -def inject_member_event( +async def inject_member_event( hs: synapse.server.HomeServer, room_id: str, sender: str, @@ -46,7 +44,7 @@ def inject_member_event( if extra_content: content.update(extra_content) - return inject_event( + return await inject_event( hs, room_id=room_id, type=EventTypes.Member, @@ -57,7 +55,7 @@ def inject_member_event( ) -def inject_event( +async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, @@ -72,37 +70,27 @@ def inject_event( prev_event_ids: prev_events for the event. If not specified, will be looked up kwargs: fields for the event to be created """ - test_reactor = hs.get_reactor() - - event, context = create_event(hs, room_version, prev_event_ids, **kwargs) + event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - d = hs.get_storage().persistence.persist_event(event, context) - test_reactor.advance(0) - get_awaitable_result(d) + await hs.get_storage().persistence.persist_event(event, context) return event -def create_event( +async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: - test_reactor = hs.get_reactor() - if room_version is None: - d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) - test_reactor.advance(0) - room_version = get_awaitable_result(d) + room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) - d = hs.get_event_creation_handler().create_new_client_event( + event, context = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) - test_reactor.advance(0) - event, context = get_awaitable_result(d) return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index f7381b2885..531a9b9118 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() - yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @defer.inlineCallbacks def test_filtering(self): @@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # # before we do that, we persist some other events to act as state. - self.inject_visibility("@admin:hs", "joined") + yield self.inject_visibility("@admin:hs", "joined") for i in range(0, 10): yield self.inject_room_member("@resident%i:hs" % i) @@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) for i in range(0, len(events_to_filter)): @@ -137,10 +137,12 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) + ) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) ) - yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -158,11 +160,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -179,11 +183,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -265,8 +271,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): storage.main = test_store storage.state = test_store - filtered = yield filter_events_for_server( - test_store, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) ) logger.info("Filtering took %f seconds", time.time() - start) diff --git a/tests/unittest.py b/tests/unittest.py index 3175a3fa02..68d2586efd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase): user: MXID of the user to inject the membership for. membership: The membership type. """ - event_injection.inject_member_event(self.hs, room, user, membership) + self.get_success( + event_injection.inject_member_event(self.hs, room, user, membership) + ) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 4d17355a5c..b33b6860d4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -638,14 +638,8 @@ class DeferredMockCallable(object): ) -@defer.inlineCallbacks -def create_room(hs, room_id, creator_id): +async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room - - Args: - hs - room_id (str) - creator_id (str) """ persistence_store = hs.get_storage().persistence @@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - yield store.store_room( + await store.store_room( room_id=room_id, room_creator_user_id=creator_id, is_public=False, @@ -671,6 +665,6 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield event_creation_handler.create_new_client_event(builder) + event, context = await event_creation_handler.create_new_client_event(builder) - yield persistence_store.persist_event(event, context) + await persistence_store.persist_event(event, context) |