diff options
Diffstat (limited to 'tests')
102 files changed, 5529 insertions, 1481 deletions
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 5527e278db..d66aeb00eb 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -1,6 +1,6 @@ import synapse from synapse.app.phone_stats_home import start_phone_stats_home -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest from tests.unittest import HomeserverTestCase diff --git a/tests/config/test_base.py b/tests/config/test_base.py index 84ae3b88ae..baa5313fb3 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -30,7 +30,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +60,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], tmp_dir) + self.hs.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -74,8 +74,66 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): "Template file did not contain our test string", ) + def test_multiple_custom_template_directories(self): + """Tests that directories are searched in the right order if multiple custom + template directories are provided. + """ + # Create two temporary directories on the filesystem. + tempdirs = [ + tempfile.TemporaryDirectory(), + tempfile.TemporaryDirectory(), + ] + + # Create one template in each directory, whose content is the index of the + # directory in the list. + template_filename = "my_template.html.j2" + for i in range(len(tempdirs)): + tempdir = tempdirs[i] + template_path = os.path.join(tempdir.name, template_filename) + + with open(template_path, "w") as fp: + fp.write(str(i)) + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [template_filename], + (td.name for td in tempdirs), + ) + )[0] + + # Test that we got the template we dropped in the first directory in the list. + self.assertEqual(template.render(), "0") + + # Add another template, this one only in the second directory in the list, so we + # can test that the second directory is still searched into when no matching file + # could be found in the first one. + other_template_name = "my_other_template.html.j2" + other_template_path = os.path.join(tempdirs[1].name, other_template_name) + + with open(other_template_path, "w") as fp: + fp.write("hello world") + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [other_template_name], + (td.name for td in tempdirs), + ) + )[0] + + # Test that the file has the expected content. + self.assertEqual(template.render(), "hello world") + + # Cleanup the temporary directories manually since we're not using a context + # manager. + for td in tempdirs: + td.cleanup() + def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): self.hs.config.read_templates( - ["some_filename.html"], "a_nonexistent_directory" + ["some_filename.html"], ("a_nonexistent_directory",) ) diff --git a/tests/config/test_server.py b/tests/config/test_server.py index 6f2b9e997d..b6f21294ba 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -35,7 +35,7 @@ class ServerConfigTestCase(unittest.TestCase): def test_unsecure_listener_no_listeners_open_private_ports_false(self): conf = yaml.safe_load( ServerConfig().generate_config_section( - "che.org", "/data_dir_path", False, None + "che.org", "/data_dir_path", False, None, config_dir_path="CONFDIR" ) ) @@ -55,7 +55,7 @@ class ServerConfigTestCase(unittest.TestCase): def test_unsecure_listener_no_listeners_open_private_ports_true(self): conf = yaml.safe_load( ServerConfig().generate_config_section( - "che.org", "/data_dir_path", True, None + "che.org", "/data_dir_path", True, None, config_dir_path="CONFDIR" ) ) @@ -89,7 +89,7 @@ class ServerConfigTestCase(unittest.TestCase): conf = yaml.safe_load( ServerConfig().generate_config_section( - "this.one.listens", "/data_dir_path", True, listeners + "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR" ) ) @@ -123,7 +123,7 @@ class ServerConfigTestCase(unittest.TestCase): conf = yaml.safe_load( ServerConfig().generate_config_section( - "this.one.listens", "/data_dir_path", True, listeners + "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR" ) ) diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 3f41e99950..3b3866bff8 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -17,12 +17,12 @@ from unittest.mock import Mock import attr from synapse.api.constants import EduTypes -from synapse.events.presence_router import PresenceRouter +from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi from synapse.rest import admin -from synapse.rest.client.v1 import login, presence, room +from synapse.rest.client import login, presence, room from synapse.types import JsonDict, StreamToken, create_requester from tests.handlers.test_sync import generate_sync_config @@ -34,7 +34,7 @@ class PresenceRouterTestConfig: users_who_should_receive_all_presence = attr.ib(type=List[str], default=[]) -class PresenceRouterTestModule: +class LegacyPresenceRouterTestModule: def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi): self._config = config self._module_api = module_api @@ -77,6 +77,53 @@ class PresenceRouterTestModule: return config +class PresenceRouterTestModule: + def __init__(self, config: PresenceRouterTestConfig, api: ModuleApi): + self._config = config + self._module_api = api + api.register_presence_router_callbacks( + get_users_for_states=self.get_users_for_states, + get_interested_users=self.get_interested_users, + ) + + async def get_users_for_states( + self, state_updates: Iterable[UserPresenceState] + ) -> Dict[str, Set[UserPresenceState]]: + users_to_state = { + user_id: set(state_updates) + for user_id in self._config.users_who_should_receive_all_presence + } + return users_to_state + + async def get_interested_users( + self, user_id: str + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + if user_id in self._config.users_who_should_receive_all_presence: + return PresenceRouter.ALL_USERS + + return set() + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterTestConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterTestConfig() + + config.users_who_should_receive_all_presence = config_dict.get( + "users_who_should_receive_all_presence" + ) + + return config + + class PresenceRouterTestCase(FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -86,9 +133,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ] def make_homeserver(self, reactor, clock): - return self.setup_test_homeserver( + hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), ) + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + + load_legacy_presence_router(hs) + + return hs def prepare(self, reactor, clock, homeserver): self.sync_handler = self.hs.get_sync_handler() @@ -98,7 +153,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): { "presence": { "presence_router": { - "module": __name__ + ".PresenceRouterTestModule", + "module": __name__ + ".LegacyPresenceRouterTestModule", "config": { "users_who_should_receive_all_presence": [ "@presence_gobbler:test", @@ -109,7 +164,28 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): "send_federation": True, } ) + def test_receiving_all_presence_legacy(self): + self.receiving_all_presence_test_body() + + @override_config( + { + "modules": [ + { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler:test", + ] + }, + }, + ], + "send_federation": True, + } + ) def test_receiving_all_presence(self): + self.receiving_all_presence_test_body() + + def receiving_all_presence_test_body(self): """Test that a user that does not share a room with another other can receive presence for them, due to presence routing. """ @@ -203,7 +279,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): { "presence": { "presence_router": { - "module": __name__ + ".PresenceRouterTestModule", + "module": __name__ + ".LegacyPresenceRouterTestModule", "config": { "users_who_should_receive_all_presence": [ "@presence_gobbler1:test", @@ -216,7 +292,30 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): "send_federation": True, } ) + def test_send_local_online_presence_to_with_module_legacy(self): + self.send_local_online_presence_to_with_module_test_body() + + @override_config( + { + "modules": [ + { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler1:test", + "@presence_gobbler2:test", + "@far_away_person:island", + ] + }, + }, + ], + "send_federation": True, + } + ) def test_send_local_online_presence_to_with_module(self): + self.send_local_online_presence_to_with_module_test_body() + + def send_local_online_presence_to_with_module_test_body(self): """Tests that send_local_presence_to_users sends local online presence to a set of specified local and remote users, with a custom PresenceRouter module enabled. """ diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 48e98aac79..ca27388ae8 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -14,7 +14,7 @@ from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest from tests.test_utils.event_injection import create_event diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index e2a5fc018c..5446fda5e7 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -322,7 +322,7 @@ class PruneEventTestCase(unittest.TestCase): }, ) - # After MSC3083, alias events have no special behavior. + # After MSC3083, the allow key is protected from redaction. self.run_test( { "type": "m.room.join_rules", @@ -341,7 +341,51 @@ class PruneEventTestCase(unittest.TestCase): "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC3083, + room_version=RoomVersions.V8, + ) + + def test_member(self): + """Member events have changed behavior starting with MSC3375.""" + self.run_test( + { + "type": "m.room.member", + "event_id": "$test:domain", + "content": { + "membership": "join", + "join_authorised_via_users_server": "@user:domain", + "other_key": "stripped", + }, + }, + { + "type": "m.room.member", + "event_id": "$test:domain", + "content": {"membership": "join"}, + "signatures": {}, + "unsigned": {}, + }, + ) + + # After MSC3375, the join_authorised_via_users_server key is protected + # from redaction. + self.run_test( + { + "type": "m.room.member", + "content": { + "membership": "join", + "join_authorised_via_users_server": "@user:domain", + "other_key": "stripped", + }, + }, + { + "type": "m.room.member", + "content": { + "membership": "join", + "join_authorised_via_users_server": "@user:domain", + }, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.V9, ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 1a809b2a6a..7b486aba4a 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -16,7 +16,7 @@ from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID from tests import unittest diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 802c5ad299..f0aa8ed9db 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -6,7 +6,7 @@ from synapse.events import EventBase from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b00dd143d6..65b18fbd7a 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.constants import RoomEncryptionAlgorithms from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt from tests.test_utils import make_awaitable diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 1737891564..0b60cc4261 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -19,7 +19,7 @@ from parameterized import parameterized from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index aab44bce4a..663960ff53 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.events import builder from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import RoomAlias @@ -208,7 +208,7 @@ class FederationKnockingTestCase( async def _check_event_auth(origin, event, context, *args, **kwargs): return context - homeserver.get_federation_handler()._check_event_auth = _check_event_auth + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth return super().prepare(reactor, clock, homeserver) diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 18a734daf4..59de1142b1 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -15,12 +15,10 @@ from collections import Counter from unittest.mock import Mock -import synapse.api.errors -import synapse.handlers.admin import synapse.rest.admin import synapse.storage from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 024c5e963c..43998020b2 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - def _mkservice(self, is_interested): + def test_get_3pe_protocols_no_appservices(self): + self.mock_store.get_app_services.return_value = [] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_no_protocols(self): + service = self._mkservice(False, []) + self.mock_store.get_app_services.return_value = [service] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_protocol_no_response(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals(response, {}) + + def test_get_3pe_protocols_select_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_multiple_protocol(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["other-protocol"]) + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called() + self.assertEquals( + response, + { + "my-protocol": {"x-protocol-data": 42, "instances": []}, + "other-protocol": {"x-protocol-data": 42, "instances": []}, + }, + ) + + def test_get_3pe_protocols_multiple_info(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["my-protocol"]) + + async def get_3pe_protocol(service, unusedProtocol): + if service == service_one: + return { + "x-protocol-data": 42, + "instances": [{"desc": "Alice's service"}], + } + if service == service_two: + return { + "x-protocol-data": 36, + "x-not-used": 45, + "instances": [{"desc": "Bob's service"}], + } + raise Exception("Unexpected service") + + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol = get_3pe_protocol + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + # It's expected that the second service's data doesn't appear in the response + self.assertEquals( + response, + { + "my-protocol": { + "x-protocol-data": 42, + "instances": [ + { + "desc": "Alice's service", + }, + {"desc": "Bob's service"}, + ], + }, + }, + ) + + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" + service.protocols = protocols return service def _mkservice_alias(self, is_interested_in_alias): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 7a8041ab44..a0a48b564e 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -19,7 +19,7 @@ import synapse import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 4140fcefc2..6c67a16de9 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -22,7 +22,7 @@ from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.stringutils import random_string from tests import unittest @@ -130,7 +130,9 @@ class FederationTestCase(unittest.HomeserverTestCase): ) with LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -182,7 +184,9 @@ class FederationTestCase(unittest.HomeserverTestCase): ) with LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -311,7 +315,9 @@ class FederationTestCase(unittest.HomeserverTestCase): with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver d = run_in_background( - self.handler.on_receive_pdu, OTHER_SERVER, message_event + self.hs.get_federation_event_handler().on_receive_pdu, + OTHER_SERVER, + message_event, ) self.get_success(d) @@ -382,7 +388,9 @@ class FederationTestCase(unittest.HomeserverTestCase): join_event.signatures[other_server] = {"x": "y"} with LoggingContext("send_join"): d = run_in_background( - self.handler.on_send_membership_event, other_server, join_event + self.hs.get_federation_event_handler().on_send_membership_event, + other_server, + join_event, ) self.get_success(d) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index a8a9fc5b62..8a8d369fac 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import create_requester from synapse.util.stringutils import random_string diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 32651db096..38e6d9f536 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,8 +20,7 @@ from unittest.mock import Mock from twisted.internet import defer import synapse -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import devices +from synapse.rest.client import devices, login from synapse.types import JsonDict from tests import unittest diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 18e92e90d7..671dc7d083 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Optional from unittest.mock import Mock, call from signedjson.key import generate_signing_key @@ -33,7 +33,7 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest @@ -339,8 +339,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): class PresenceTimeoutTestCase(unittest.TestCase): + """Tests different timers and that the timer does not change `status_msg` of user.""" + def test_idle_timer(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -348,12 +351,14 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + self.assertEquals(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -361,6 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): presence state into unavailable. """ user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -368,15 +374,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.BUSY, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.BUSY) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -384,15 +393,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=0, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -400,6 +412,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - SYNC_ONLINE_TIMEOUT - 1, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -408,9 +421,11 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.ONLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -419,12 +434,13 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state, new_state) + self.assertEquals(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -444,6 +460,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_federation_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -452,6 +469,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -460,9 +478,11 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -471,6 +491,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1, last_user_sync_ts=now, last_federation_update_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) @@ -516,6 +537,144 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): ) self.assertEqual(state.state, PresenceState.OFFLINE) + def test_user_goes_offline_by_timeout_status_msg_remain(self): + """Test that if a user doesn't update the records for a while + users presence goes `OFFLINE` because of timeout and `status_msg` remains. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Check that if we wait a while without telling the handler the user has + # stopped syncing that their presence state doesn't get timed out. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, status_msg) + + # Check that if the timeout fires, then the syncing user gets timed out + self.reactor.advance(SYNC_ONLINE_TIMEOUT) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should remain even after going offline + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, status_msg) + + def test_user_goes_offline_manually_with_no_status_msg(self): + """Test that if a user change presence manually to `OFFLINE` + and no status is set, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.OFFLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, None) + + def test_user_goes_offline_manually_with_status_msg(self): + """Test that if a user change presence manually to `OFFLINE` + and a status is set, that `status_msg` appears. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self._set_presencestate_with_status_msg( + user_id, PresenceState.OFFLINE, "And now here." + ) + + def test_user_reset_online_with_no_status(self): + """Test that if a user set again the presence manually + and no status is set, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online again + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.ONLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should remain even after going offline + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, None) + + def test_set_presence_with_status_msg_none(self): + """Test that if a user set again the presence manually + and status is `None`, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online and `status_msg = None` + self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) + + def _set_presencestate_with_status_msg( + self, user_id: str, state: PresenceState, status_msg: Optional[str] + ): + """Set a PresenceState and status_msg and check the result. + + Args: + user_id: User for that the status is to be set. + PresenceState: The new PresenceState. + status_msg: Status message that is to be set. + """ + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), + {"presence": state, "status_msg": status_msg}, + ) + ) + + new_state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(new_state.state, state) + self.assertEqual(new_state.status_msg, status_msg) + class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): @@ -726,7 +885,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.federation_sender = hs.get_federation_sender() self.event_builder_factory = hs.get_event_builder_factory() - self.federation_handler = hs.get_federation_handler() + self.federation_event_handler = hs.get_federation_event_handler() self.presence_handler = hs.get_presence_handler() # self.event_builder_for_2 = EventBuilderFactory(hs) @@ -867,7 +1026,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) ) - self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) + self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) # Check that it was successfully persisted. self.get_success(self.store.get_event(event.event_id)) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 93a9a084b2..732a12c9bd 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -286,6 +286,29 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) + def test_handles_string_data(self): + """ + Tests that an invalid shape for read-receipts is handled. + Context: https://github.com/matrix-org/synapse/issues/10603 + """ + + self._test_filters_hidden( + [ + { + "content": { + "$14356419edgd14394fHBLK:matrix.org": { + "m.read": { + "@rikj:jki.re": "string", + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + }, + ], + [], + ) + def _test_filters_hidden( self, events: List[JsonDict], expected_output: List[JsonDict] ): diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py new file mode 100644 index 0000000000..fcde5dab72 --- /dev/null +++ b/tests/handlers/test_room.py @@ -0,0 +1,108 @@ +import synapse +from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms +from synapse.rest.client import login, room + +from tests import unittest +from tests.unittest import override_config + + +class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + ] + + @override_config({"encryption_enabled_by_default_for_room_type": "all"}) + def test_encrypted_by_default_config_option_all(self): + """Tests that invite-only and non-invite-only rooms have encryption enabled by + default when the config option encryption_enabled_by_default_for_room_type is "all". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + @override_config({"encryption_enabled_by_default_for_room_type": "invite"}) + def test_encrypted_by_default_config_option_invite(self): + """Tests that only new, invite-only rooms have encryption enabled by default when + the config option encryption_enabled_by_default_for_room_type is "invite". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) + + @override_config({"encryption_enabled_by_default_for_room_type": "off"}) + def test_encrypted_by_default_config_option_off(self): + """Tests that neither new invite-only nor non-invite-only rooms have encryption + enabled by default when the config option + encryption_enabled_by_default_for_room_type is "off". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py new file mode 100644 index 0000000000..d3d0bf1ac5 --- /dev/null +++ b/tests/handlers/test_room_summary.py @@ -0,0 +1,992 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Iterable, List, Optional, Tuple +from unittest import mock + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RestrictedJoinRuleTypes, + RoomTypes, +) +from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID + +from tests import unittest + + +def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0): + result = mock.Mock(name=room_id) + result.room_id = room_id + result.content = {} + result.origin_server_ts = origin_server_ts + if order is not None: + result.content["order"] = order + return result + + +def _order(*events): + return sorted(events, key=_child_events_comparison_key) + + +class TestSpaceSummarySort(unittest.TestCase): + def test_no_order_last(self): + """An event with no ordering is placed behind those with an ordering.""" + ev1 = _create_event("!abc:test") + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order(self): + """The ordering should be used.""" + ev1 = _create_event("!abc:test", "xyz") + ev2 = _create_event("!xyz:test", "abc") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order_origin_server_ts(self): + """Origin server is a tie-breaker for ordering.""" + ev1 = _create_event("!abc:test", origin_server_ts=10) + ev2 = _create_event("!xyz:test", origin_server_ts=30) + + self.assertEqual([ev1, ev2], _order(ev1, ev2)) + + def test_order_room_id(self): + """Room ID is a final tie-breaker for ordering.""" + ev1 = _create_event("!abc:test") + ev2 = _create_event("!xyz:test") + + self.assertEqual([ev1, ev2], _order(ev1, ev2)) + + def test_invalid_ordering_type(self): + """Invalid orderings are considered the same as missing.""" + ev1 = _create_event("!abc:test", 1) + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", {}) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", []) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", True) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_invalid_ordering_value(self): + """Invalid orderings are considered the same as missing.""" + ev1 = _create_event("!abc:test", "foo\n") + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", "a" * 51) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + +class SpaceSummaryTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.hs = hs + self.handler = self.hs.get_room_summary_handler() + + # Create a user. + self.user = self.register_user("user", "pass") + self.token = self.login("user", "pass") + + # Create a space and a child room. + self.space = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + self.room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, self.room, self.token) + + def _add_child( + self, space_id: str, room_id: str, token: str, order: Optional[str] = None + ) -> None: + """Add a child room to a space.""" + content: JsonDict = {"via": [self.hs.hostname]} + if order is not None: + content["order"] = order + self.helper.send_state( + space_id, + event_type=EventTypes.SpaceChild, + body=content, + tok=token, + state_key=room_id, + ) + + def _assert_rooms( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs and events are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + if children: + children_ids.extend([(room_id, child_id) for child_id in children]) + self.assertCountEqual( + [room.get("room_id") for room in result["rooms"]], room_ids + ) + self.assertCountEqual( + [ + (event.get("room_id"), event.get("state_key")) + for event in result["events"] + ], + children_ids, + ) + + def _assert_hierarchy( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + result_room_ids = [] + result_children_ids = [] + for result_room in result["rooms"]: + result_room_ids.append(result_room["room_id"]) + result_children_ids.append( + [ + (cs["room_id"], cs["state_key"]) + for cs in result_room.get("children_state") + ] + ) + + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + children_ids.append([(room_id, child_id) for child_id in children]) + + # Note that order matters. + self.assertEqual(result_room_ids, room_ids) + self.assertEqual(result_children_ids, children_ids) + + def _poke_fed_invite(self, room_id: str, from_user: str) -> None: + """ + Creates a invite (as if received over federation) for the room from the + given hostname. + + Args: + room_id: The room ID to issue an invite for. + fed_hostname: The user to invite from. + """ + # Poke an invite over federation into the database. + fed_handler = self.hs.get_federation_handler() + fed_hostname = UserID.from_string(from_user).domain + event = make_event_from_dict( + { + "room_id": room_id, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": from_user, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + + def test_simple_space(self): + """Test a simple space with a single room.""" + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + # The result should have the space and the room in it, along with a link + # from space -> room. + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_visibility(self): + """A user not in a space cannot inspect it.""" + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # The user can see the space since it is publicly joinable. + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # If the space is made invite-only, it should no longer be viewable. + self.helper.send_state( + self.space, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # If the space is made world-readable it should return a result. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Make it not world-readable again and confirm it results in an error. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # Join the space and results should be returned. + self.helper.invite(self.space, targ=user2, tok=self.token) + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Attempting to view an unknown room returns the same error. + self.get_failure( + self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + self.get_failure( + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + + def _create_room_with_join_rule( + self, join_rule: str, room_version: Optional[str] = None, **extra_content + ) -> str: + """Create a room with the given join rule and add it to the space.""" + room_id = self.helper.create_room_as( + self.user, + room_version=room_version, + tok=self.token, + extra_content={ + "initial_state": [ + { + "type": EventTypes.JoinRules, + "state_key": "", + "content": { + "join_rule": join_rule, + **extra_content, + }, + } + ] + }, + ) + self._add_child(self.space, room_id, self.token) + return room_id + + def test_filtering(self): + """ + Rooms should be properly filtered to only include rooms the user has access to. + """ + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # Create a few rooms which will have different properties. + public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) + knock_room = self._create_room_with_join_rule( + JoinRules.KNOCK, room_version=RoomVersions.V7.identifier + ) + not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(invited_room, targ=user2, tok=self.token) + restricted_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[], + ) + restricted_accessible_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[ + { + "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, + "room_id": self.space, + "via": [self.hs.hostname], + } + ], + ) + world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.send_state( + world_readable_room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + joined_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(joined_room, targ=user2, tok=self.token) + self.helper.join(joined_room, user2, tok=token2) + + # Join the space. + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [ + ( + self.space, + [ + self.room, + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (self.room, ()), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + def test_complex_space(self): + """ + Create a "complex" space to see how it handles things like loops and subspaces. + """ + # Create an inaccessible room. + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) + # This is a bit odd as "user" is adding a room they don't know about, but + # it works for the tests. + self._add_child(self.space, room2, self.token) + + # Create a subspace under the space with an additional room in it. + subspace = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + subroom = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, subspace, token=self.token) + self._add_child(subspace, subroom, token=self.token) + # Also add the two rooms from the space into this subspace (causing loops). + self._add_child(subspace, self.room, token=self.token) + self._add_child(subspace, room2, self.token) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + + # The result should include each room a single time and each link. + expected = [ + (self.space, [self.room, room2, subspace]), + (self.room, ()), + (subspace, [subroom, self.room, room2]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_pagination(self): + """Test simple pagination works.""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + # The result should have the space and all of the links, plus some of the + # rooms and a pagination token. + expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)] + expected += [(room_id, ()) for room_id in room_ids[:6]] + self._assert_hierarchy(result, expected) + self.assertIn("next_batch", result) + + # Check the next page. + result = self.get_success( + self.handler.get_room_hierarchy( + self.user, self.space, limit=5, from_token=result["next_batch"] + ) + ) + # The result should have the space and the room in it, along with a link + # from space -> room. + expected = [(room_id, ()) for room_id in room_ids[6:]] + self._assert_hierarchy(result, expected) + self.assertNotIn("next_batch", result) + + def test_invalid_pagination_token(self): + """An invalid pagination token, or changing other parameters, shoudl be rejected.""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + self.assertIn("next_batch", result) + + # Changing the room ID, suggested-only, or max-depth causes an error. + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.room, from_token=result["next_batch"] + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, + self.space, + suggested_only=True, + from_token=result["next_batch"], + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.space, max_depth=0, from_token=result["next_batch"] + ), + SynapseError, + ) + + # An invalid token is ignored. + self.get_failure( + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + SynapseError, + ) + + def test_max_depth(self): + """Create a deep tree to test the max depth against.""" + spaces = [self.space] + rooms = [self.room] + for _ in range(5): + spaces.append( + self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": { + EventContentFields.ROOM_TYPE: RoomTypes.SPACE + } + }, + ) + ) + self._add_child(spaces[-2], spaces[-1], self.token) + rooms.append(self.helper.create_room_as(self.user, tok=self.token)) + self._add_child(spaces[-1], rooms[-1], self.token) + + # Test just the space itself. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + ) + expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] + self._assert_hierarchy(result, expected) + + # A single additional layer. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + ) + expected += [ + (rooms[0], ()), + (spaces[1], [rooms[1], spaces[2]]), + ] + self._assert_hierarchy(result, expected) + + # A few layers. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + ) + expected += [ + (rooms[1], ()), + (spaces[2], [rooms[2], spaces[3]]), + (rooms[2], ()), + (spaces[3], [rooms[3], spaces[4]]), + ] + self._assert_hierarchy(result, expected) + + def test_unknown_room_version(self): + """ + If an room with an unknown room version is encountered it should not cause + the entire summary to skip. + """ + # Poke the database and update the room version to an unknown one. + self.get_success( + self.hs.get_datastores().main.db_pool.simple_update( + "rooms", + keyvalues={"room_id": self.room}, + updatevalues={"room_version": "unknown-room-version"}, + desc="updated-room-version", + ) + ) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + # The result should have only the space, along with a link from space -> room. + expected = [(self.space, [self.room])] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_complex(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + subroom = "#subroom:" + fed_hostname + + # Generate some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. + requested_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ) + child_room = { + "room_id": subroom, + "world_readable": True, + } + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [ + requested_room_entry, + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return requested_room_entry, {subroom: child_room}, set() + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + (subspace, [subroom]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_filtering(self): + """ + Rooms returned over federation should be properly filtered to only include + rooms the user has access to. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + + # Create a few rooms which will have different properties. + public_room = "#public:" + fed_hostname + knock_room = "#knock:" + fed_hostname + not_invited_room = "#not_invited:" + fed_hostname + invited_room = "#invited:" + fed_hostname + restricted_room = "#restricted:" + fed_hostname + restricted_accessible_room = "#restricted_accessible:" + fed_hostname + world_readable_room = "#world_readable:" + fed_hostname + joined_room = self.helper.create_room_as(self.user, tok=self.token) + + # Poke an invite over federation into the database. + self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) + + # Note that these entries are brief, but should contain enough info. + children_rooms = ( + ( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + ( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + ( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + ( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + ( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + ( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ) + + subspace_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": room_id, + "content": {"via": [fed_hostname]}, + } + for room_id, _ in children_rooms + ], + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [subspace_room_entry] + [ + # A copy is made of the room data since the allowed_spaces key + # is removed. + _RoomEntry(child_room[0], dict(child_room[1])) + for child_room in children_rooms + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return subspace_room_entry, dict(children_rooms), set() + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + ( + subspace, + [ + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_invited(self): + """ + A room which the user was invited to should be included in the response. + + This differs from test_fed_filtering in that the room itself is being + queried over federation, instead of it being included as a sub-room of + a space in the response. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#subroom:" + fed_hostname + + # Poke an invite over federation into the database. + self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + + fed_room_entry = _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [fed_room_entry] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return fed_room_entry, {}, set() + + # Add a room to the space which is on another server. + self._add_child(self.space, fed_room, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, fed_room]), + (self.room, ()), + (fed_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + +class RoomSummaryTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.hs = hs + self.handler = self.hs.get_room_summary_handler() + + # Create a user. + self.user = self.register_user("user", "pass") + self.token = self.login("user", "pass") + + # Create a simple room. + self.room = self.helper.create_room_as(self.user, tok=self.token) + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + + def test_own_room(self): + """Test a simple room created by the requester.""" + result = self.get_success(self.handler.get_room_summary(self.user, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + def test_visibility(self): + """A user not in a private room cannot get its summary.""" + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # The user cannot see the room. + self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError) + + # If the room is made world-readable it should return a result. + self.helper.send_state( + self.room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + # Make it not world-readable again and confirm it results in an error. + self.helper.send_state( + self.room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError) + + # If the room is made public it should return a result. + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.PUBLIC}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + # Join the space, make it invite-only again and results should be returned. + self.helper.join(self.room, user2, tok=token2) + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py new file mode 100644 index 0000000000..6f77b1237c --- /dev/null +++ b/tests/handlers/test_send_email.py @@ -0,0 +1,112 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Tuple + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.address import IPv4Address +from twisted.internet.defer import ensureDeferred +from twisted.mail import interfaces, smtp + +from tests.server import FakeTransport +from tests.unittest import HomeserverTestCase + + +@implementer(interfaces.IMessageDelivery) +class _DummyMessageDelivery: + def __init__(self): + # (recipient, message) tuples + self.messages: List[Tuple[smtp.Address, bytes]] = [] + + def receivedHeader(self, helo, origin, recipients): + return None + + def validateFrom(self, helo, origin): + return origin + + def record_message(self, recipient: smtp.Address, message: bytes): + self.messages.append((recipient, message)) + + def validateTo(self, user: smtp.User): + return lambda: _DummyMessage(self, user) + + +@implementer(interfaces.IMessageSMTP) +class _DummyMessage: + """IMessageSMTP implementation which saves the message delivered to it + to the _DummyMessageDelivery object. + """ + + def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User): + self._delivery = delivery + self._user = user + self._buffer: List[bytes] = [] + + def lineReceived(self, line): + self._buffer.append(line) + + def eomReceived(self): + message = b"\n".join(self._buffer) + b"\n" + self._delivery.record_message(self._user.dest, message) + return defer.succeed(b"saved") + + def connectionLost(self): + pass + + +class SendEmailHandlerTestCase(HomeserverTestCase): + def test_send_email(self): + """Happy-path test that we can send email to a non-TLS server.""" + h = self.hs.get_send_email_handler() + d = ensureDeferred( + h.send_email( + "foo@bar.com", "test subject", "Tests", "HTML content", "Text content" + ) + ) + # there should be an attempt to connect to localhost:25 + self.assertEqual(len(self.reactor.tcpClients), 1) + (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ + 0 + ] + self.assertEqual(host, "localhost") + self.assertEqual(port, 25) + + # wire it up to an SMTP server + message_delivery = _DummyMessageDelivery() + server_protocol = smtp.ESMTP() + server_protocol.delivery = message_delivery + # make sure that the server uses the test reactor to set timeouts + server_protocol.callLater = self.reactor.callLater # type: ignore[assignment] + + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor)) + server_protocol.makeConnection( + FakeTransport( + client_protocol, + self.reactor, + peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + ) + ) + + # the message should now get delivered + self.get_success(d, by=0.1) + + # check it arrived + self.assertEqual(len(message_delivery.messages), 1) + user, msg = message_delivery.messages.pop() + self.assertEqual(str(user), "foo@bar.com") + self.assertIn(b"Subject: test subject", msg) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py deleted file mode 100644 index 3f73ad7f94..0000000000 --- a/tests/handlers/test_space_summary.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Iterable, Optional, Tuple -from unittest import mock - -from synapse.api.constants import ( - EventContentFields, - EventTypes, - HistoryVisibility, - JoinRules, - Membership, - RestrictedJoinRuleTypes, - RoomTypes, -) -from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key -from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.server import HomeServer -from synapse.types import JsonDict - -from tests import unittest - - -def _create_event(room_id: str, order: Optional[Any] = None): - result = mock.Mock() - result.room_id = room_id - result.content = {} - if order is not None: - result.content["order"] = order - return result - - -def _order(*events): - return sorted(events, key=_child_events_comparison_key) - - -class TestSpaceSummarySort(unittest.TestCase): - def test_no_order_last(self): - """An event with no ordering is placed behind those with an ordering.""" - ev1 = _create_event("!abc:test") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order(self): - """The ordering should be used.""" - ev1 = _create_event("!abc:test", "xyz") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order_room_id(self): - """Room ID is a tie-breaker for ordering.""" - ev1 = _create_event("!abc:test", "abc") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev1, ev2], _order(ev1, ev2)) - - def test_invalid_ordering_type(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", 1) - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", {}) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", []) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", True) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_invalid_ordering_value(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", "foo\n") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", "a" * 51) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - -class SpaceSummaryTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, hs: HomeServer): - self.hs = hs - self.handler = self.hs.get_space_summary_handler() - - # Create a user. - self.user = self.register_user("user", "pass") - self.token = self.login("user", "pass") - - # Create a space and a child room. - self.space = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - self.room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, self.room, self.token) - - def _add_child(self, space_id: str, room_id: str, token: str) -> None: - """Add a child room to a space.""" - self.helper.send_state( - space_id, - event_type=EventTypes.SpaceChild, - body={"via": [self.hs.hostname]}, - tok=token, - state_key=room_id, - ) - - def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None: - """Assert that the expected room IDs are in the response.""" - self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms) - - def _assert_events( - self, result: JsonDict, events: Iterable[Tuple[str, str]] - ) -> None: - """Assert that the expected parent / child room IDs are in the response.""" - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - events, - ) - - def test_simple_space(self): - """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should have the space and the room in it, along with a link - # from space -> room. - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) - - def test_visibility(self): - """A user not in a space cannot inspect it.""" - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # The user cannot see the space. - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - - # If the space is made world-readable it should return a result. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) - - # Make it not world-readable again and confirm it results in an error. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.JOINED}, - tok=self.token, - ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - - # Join the space and results should be returned. - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) - - def _create_room_with_join_rule( - self, join_rule: str, room_version: Optional[str] = None, **extra_content - ) -> str: - """Create a room with the given join rule and add it to the space.""" - room_id = self.helper.create_room_as( - self.user, - room_version=room_version, - tok=self.token, - extra_content={ - "initial_state": [ - { - "type": EventTypes.JoinRules, - "state_key": "", - "content": { - "join_rule": join_rule, - **extra_content, - }, - } - ] - }, - ) - self._add_child(self.space, room_id, self.token) - return room_id - - def test_filtering(self): - """ - Rooms should be properly filtered to only include rooms the user has access to. - """ - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # Create a few rooms which will have different properties. - public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) - knock_room = self._create_room_with_join_rule( - JoinRules.KNOCK, room_version=RoomVersions.V7.identifier - ) - not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(invited_room, targ=user2, tok=self.token) - restricted_room = self._create_room_with_join_rule( - JoinRules.MSC3083_RESTRICTED, - room_version=RoomVersions.MSC3083.identifier, - allow=[], - ) - restricted_accessible_room = self._create_room_with_join_rule( - JoinRules.MSC3083_RESTRICTED, - room_version=RoomVersions.MSC3083.identifier, - allow=[ - { - "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, - "room_id": self.space, - "via": [self.hs.hostname], - } - ], - ) - world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.send_state( - world_readable_room, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - joined_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(joined_room, targ=user2, tok=self.token) - self.helper.join(joined_room, user2, tok=token2) - - # Join the space. - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - - self._assert_rooms( - result, - [ - self.space, - self.room, - public_room, - knock_room, - invited_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, public_room), - (self.space, knock_room), - (self.space, not_invited_room), - (self.space, invited_room), - (self.space, restricted_room), - (self.space, restricted_accessible_room), - (self.space, world_readable_room), - (self.space, joined_room), - ], - ) - - def test_complex_space(self): - """ - Create a "complex" space to see how it handles things like loops and subspaces. - """ - # Create an inaccessible room. - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) - # This is a bit odd as "user" is adding a room they don't know about, but - # it works for the tests. - self._add_child(self.space, room2, self.token) - - # Create a subspace under the space with an additional room in it. - subspace = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - subroom = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, subspace, token=self.token) - self._add_child(subspace, subroom, token=self.token) - # Also add the two rooms from the space into this subspace (causing loops). - self._add_child(subspace, self.room, token=self.token) - self._add_child(subspace, room2, self.token) - - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - - # The result should include each room a single time and each link. - self._assert_rooms(result, [self.space, self.room, subspace, subroom]) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, room2), - (self.space, subspace), - (subspace, subroom), - (subspace, self.room), - (subspace, room2), - ], - ) - - def test_fed_complex(self): - """ - Return data over federation and ensure that it is handled properly. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - subroom = "#subroom:" + fed_hostname - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - # Return some good data, and some bad data: - # - # * Event *back* to the root room. - # * Unrelated events / rooms - # * Multiple levels of events (in a not-useful order, e.g. grandchild - # events before child events). - - # Note that these entries are brief, but should contain enough info. - rooms = [ - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - { - "room_id": subroom, - "world_readable": True, - }, - ] - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": subroom, - "content": event_content, - }, - ] - return rooms, events - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - self._assert_rooms(result, [self.space, self.room, subspace, subroom]) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, subroom), - ], - ) - - def test_fed_filtering(self): - """ - Rooms returned over federation should be properly filtered to only include - rooms the user has access to. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - - # Create a few rooms which will have different properties. - public_room = "#public:" + fed_hostname - knock_room = "#knock:" + fed_hostname - not_invited_room = "#not_invited:" + fed_hostname - invited_room = "#invited:" + fed_hostname - restricted_room = "#restricted:" + fed_hostname - restricted_accessible_room = "#restricted_accessible:" + fed_hostname - world_readable_room = "#world_readable:" + fed_hostname - joined_room = self.helper.create_room_as(self.user, tok=self.token) - - # Poke an invite over federation into the database. - fed_handler = self.hs.get_federation_handler() - event = make_event_from_dict( - { - "room_id": invited_room, - "event_id": "!abcd:" + fed_hostname, - "type": EventTypes.Member, - "sender": "@remote:" + fed_hostname, - "state_key": self.user, - "content": {"membership": Membership.INVITE}, - "prev_events": [], - "auth_events": [], - "depth": 1, - "origin_server_ts": 1234, - } - ) - self.get_success( - fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) - ) - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - # Note that these entries are brief, but should contain enough info. - rooms = [ - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [], - }, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [self.room], - }, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ] - - # Place each room in the sub-space. - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": room["room_id"], - "content": event_content, - } - for room in rooms - ] - - # Also include the subspace. - rooms.insert( - 0, - { - "room_id": subspace, - "world_readable": True, - }, - ) - return rooms, events - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - self._assert_rooms( - result, - [ - self.space, - self.room, - subspace, - public_room, - knock_room, - invited_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, public_room), - (subspace, knock_room), - (subspace, not_invited_room), - (subspace, invited_room), - (subspace, restricted_room), - (subspace, restricted_accessible_room), - (subspace, world_readable_room), - (subspace, joined_room), - ], - ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e4059acda3..1ba4c05b9b 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main import stats from tests import unittest diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 84f05f6c58..339c039914 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION +from synapse.api.room_versions import RoomVersions from synapse.handlers.sync import SyncConfig +from synapse.rest import admin +from synapse.rest.client import knock, login, room +from synapse.server import HomeServer from synapse.types import UserID, create_requester import tests.unittest @@ -24,8 +31,14 @@ import tests.utils class SyncTestCase(tests.unittest.HomeserverTestCase): """Tests Sync Handler.""" - def prepare(self, reactor, clock, hs): - self.hs = hs + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() @@ -68,12 +81,124 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_unknown_room_version(self): + """ + A room with an unknown room version should not break sync (and should be excluded). + """ + inviter = self.register_user("creator", "pass", admin=True) + inviter_tok = self.login("@creator:test", "pass") + + user = self.register_user("user", "pass") + tok = self.login("user", "pass") + + # Do an initial sync on a different device. + requester = create_requester(user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user, device_id="dev") + ) + ) + + # Create a room as the user. + joined_room = self.helper.create_room_as(user, tok=tok) + + # Invite the user to the room as someone else. + invite_room = self.helper.create_room_as(inviter, tok=inviter_tok) + self.helper.invite(invite_room, targ=user, tok=inviter_tok) + + knock_room = self.helper.create_room_as( + inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok + ) + self.helper.send_state( + knock_room, + EventTypes.JoinRules, + {"join_rule": JoinRules.KNOCK}, + tok=inviter_tok, + ) + channel = self.make_request( + "POST", + "/_matrix/client/r0/knock/%s" % (knock_room,), + b"{}", + tok, + ) + self.assertEquals(200, channel.code, channel.result) + + # The rooms should appear in the sync response. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user) + ) + ) + self.assertIn(joined_room, [r.room_id for r in result.joined]) + self.assertIn(invite_room, [r.room_id for r in result.invited]) + self.assertIn(knock_room, [r.room_id for r in result.knocked]) + + # Test a incremental sync (by providing a since_token). + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config(user, device_id="dev"), + since_token=initial_result.next_batch, + ) + ) + self.assertIn(joined_room, [r.room_id for r in result.joined]) + self.assertIn(invite_room, [r.room_id for r in result.invited]) + self.assertIn(knock_room, [r.room_id for r in result.knocked]) + + # Poke the database and update the room version to an unknown one. + for room_id in (joined_room, invite_room, knock_room): + self.get_success( + self.hs.get_datastores().main.db_pool.simple_update( + "rooms", + keyvalues={"room_id": room_id}, + updatevalues={"room_version": "unknown-room-version"}, + desc="updated-room-version", + ) + ) + + # Blow away caches (supported room versions can only change due to a restart). + self.get_success( + self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + ) + self.store._get_event_cache.clear() + + # The rooms should be excluded from the sync response. + # Get a new request key. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user) + ) + ) + self.assertNotIn(joined_room, [r.room_id for r in result.joined]) + self.assertNotIn(invite_room, [r.room_id for r in result.invited]) + self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) + + # The rooms should also not be in an incremental sync. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config(user, device_id="dev"), + since_token=initial_result.next_batch, + ) + ) + self.assertNotIn(joined_room, [r.room_id for r in result.joined]) + self.assertNotIn(invite_room, [r.room_id for r in result.invited]) + self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) + + +_request_key = 0 + -def generate_sync_config(user_id: str) -> SyncConfig: +def generate_sync_config( + user_id: str, device_id: Optional[str] = "device_id" +) -> SyncConfig: + """Generate a sync config (with a unique request key).""" + global _request_key + _request_key += 1 return SyncConfig( - user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), + user=UserID.from_string(user_id), filter_collection=DEFAULT_FILTER_COLLECTION, is_guest=False, - request_key="request_key", - device_id="device_id", + request_key=("request_key", _request_key), + device_id=device_id, ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 549876dc85..a91d31ce61 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -16,10 +16,9 @@ from unittest.mock import Mock from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes +from synapse.api.constants import UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client import login, room, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest @@ -188,100 +187,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) - @override_config({"encryption_enabled_by_default_for_room_type": "all"}) - def test_encrypted_by_default_config_option_all(self): - """Tests that invite-only and non-invite-only rooms have encryption enabled by - default when the config option encryption_enabled_by_default_for_room_type is "all". - """ - # Create a user - user = self.register_user("user", "pass") - user_token = self.login(user, "pass") - - # Create an invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) - - # Check that the room has an encryption state event - event_content = self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - ) - self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) - - # Create a non invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) - - # Check that the room has an encryption state event - event_content = self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - ) - self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) - - @override_config({"encryption_enabled_by_default_for_room_type": "invite"}) - def test_encrypted_by_default_config_option_invite(self): - """Tests that only new, invite-only rooms have encryption enabled by default when - the config option encryption_enabled_by_default_for_room_type is "invite". - """ - # Create a user - user = self.register_user("user", "pass") - user_token = self.login(user, "pass") - - # Create an invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) - - # Check that the room has an encryption state event - event_content = self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - ) - self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) - - # Create a non invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) - - # Check that the room does not have an encryption state event - self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - expect_code=404, - ) - - @override_config({"encryption_enabled_by_default_for_room_type": "off"}) - def test_encrypted_by_default_config_option_off(self): - """Tests that neither new invite-only nor non-invite-only rooms have encryption - enabled by default when the config option - encryption_enabled_by_default_for_room_type is "off". - """ - # Create a user - user = self.register_user("user", "pass") - user_token = self.login(user, "pass") - - # Create an invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) - - # Check that the room does not have an encryption state event - self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - expect_code=404, - ) - - # Create a non invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) - - # Check that the room does not have an encryption state event - self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - expect_code=404, - ) - def test_spam_checker(self): """ A user which fails the spam checks will not appear in search results. diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a37bce08c3..992d8f94fd 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging -from typing import Optional -from unittest.mock import Mock +import os +from typing import Iterable, Optional +from unittest.mock import Mock, patch import treq from netaddr import IPSet @@ -22,11 +24,12 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions +from twisted.internet.interfaces import IProtocolFactory from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent -from twisted.web.http import HTTPChannel +from twisted.web.http import HTTPChannel, Request from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS @@ -49,24 +52,6 @@ from tests.utils import default_config logger = logging.getLogger(__name__) -test_server_connection_factory = None - - -def get_connection_factory(): - # this needs to happen once, but not until we are ready to run the first test - global test_server_connection_factory - if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory( - sanlist=[ - b"DNS:testserv", - b"DNS:target-server", - b"DNS:xn--bcher-kva.com", - b"IP:1.2.3.4", - b"IP:::1", - ] - ) - return test_server_connection_factory - # Once Async Mocks or lambdas are supported this can go away. def generate_resolve_service(result): @@ -100,24 +85,38 @@ class MatrixFederationAgentTests(unittest.TestCase): had_well_known_cache=self.had_well_known_cache, ) - self.agent = MatrixFederationAgent( - reactor=self.reactor, - tls_client_options_factory=self.tls_factory, - user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. - ip_blacklist=IPSet(), - _srv_resolver=self.mock_resolver, - _well_known_resolver=self.well_known_resolver, - ) - - def _make_connection(self, client_factory, expected_sni): + def _make_connection( + self, + client_factory: IProtocolFactory, + ssl: bool = True, + expected_sni: bytes = None, + tls_sanlist: Optional[Iterable[bytes]] = None, + ) -> HTTPChannel: """Builds a test server, and completes the outgoing client connection + Args: + client_factory: the the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + ssl: If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory + False is set only for when proxy expect http connection. + Otherwise federation requests use always https. + + expected_sni: the expected SNI value + + tls_sanlist: list of SAN entries for the TLS cert presented by the server. Returns: - HTTPChannel: the test server + the server Protocol returned by server_factory """ # build the test server - server_tls_protocol = _build_test_server(get_connection_factory()) + server_factory = _get_test_protocol_factory() + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) + + server_protocol = server_factory.buildProtocol(None) # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an @@ -128,35 +127,39 @@ class MatrixFederationAgentTests(unittest.TestCase): # stubbing that out here. client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol) + FakeTransport(server_protocol, self.reactor, client_protocol) ) - # tell the server tls protocol to send its stuff back to the client, too - server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol) + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) ) - # grab a hold of the TLS connection, in case it gets torn down - server_tls_connection = server_tls_protocol._tlsConnection - - # fish the test server back out of the server-side TLS protocol. - http_protocol = server_tls_protocol.wrappedProtocol + if ssl: + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_protocol.wrappedProtocol + # grab a hold of the TLS connection, in case it gets torn down + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None - # give the reactor a pump to get the TLS juices flowing. - self.reactor.pump((0.1,)) + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) # check the SNI - server_name = server_tls_connection.get_servername() - self.assertEqual( - server_name, - expected_sni, - "Expected SNI %s but got %s" % (expected_sni, server_name), - ) + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + f"Expected SNI {expected_sni!s} but got {server_name!s}", + ) return http_protocol @defer.inlineCallbacks - def _make_get_request(self, uri): + def _make_get_request(self, uri: bytes): """ Sends a simple GET request via the agent, and checks its logcontext management """ @@ -180,20 +183,20 @@ class MatrixFederationAgentTests(unittest.TestCase): def _handle_well_known_connection( self, - client_factory, - expected_sni, - content, + client_factory: IProtocolFactory, + expected_sni: bytes, + content: bytes, response_headers: Optional[dict] = None, - ): + ) -> HTTPChannel: """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. Args: - client_factory (IProtocolFactory): outgoing connection - expected_sni (bytes): SNI that we expect the outgoing connection to send - content (bytes): content to send back as the .well-known + client_factory: outgoing connection + expected_sni: SNI that we expect the outgoing connection to send + content: content to send back as the .well-known Returns: - HTTPChannel: server impl + server impl """ # make the connection for .well-known well_known_server = self._make_connection( @@ -209,7 +212,10 @@ class MatrixFederationAgentTests(unittest.TestCase): return well_known_server def _send_well_known_response( - self, request, content, headers: Optional[dict] = None + self, + request: Request, + content: bytes, + headers: Optional[dict] = None, ): """Check that an incoming request looks like a valid .well-known request, and send back the response. @@ -225,10 +231,37 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) - def test_get(self): + def _make_agent(self) -> MatrixFederationAgent: """ - happy-path test of a GET request with an explicit port + If a proxy server is set, the MatrixFederationAgent must be created again + because it is created too early during setUp """ + return MatrixFederationAgent( + reactor=self.reactor, + tls_client_options_factory=self.tls_factory, + user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + ip_whitelist=IPSet(), + ip_blacklist=IPSet(), + _srv_resolver=self.mock_resolver, + _well_known_resolver=self.well_known_resolver, + ) + + def test_get(self): + """happy-path test of a GET request with an explicit port""" + self._do_get() + + @patch.dict( + os.environ, + {"https_proxy": "proxy.com", "no_proxy": "testserv"}, + ) + def test_get_bypass_proxy(self): + """test of a GET request with an explicit port and bypass proxy""" + self._do_get() + + def _do_get(self): + """test of a GET request with an explicit port""" + self.agent = self._make_agent() + self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") @@ -282,10 +315,188 @@ class MatrixFederationAgentTests(unittest.TestCase): json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) + @patch.dict( + os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"} + ) + def test_get_via_http_proxy(self): + """test for federation request through a http proxy""" + self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"}, + ) + def test_get_via_http_proxy_with_auth(self): + """test for federation request through a http proxy with authentication""" + self._do_get_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"user:pass" + ) + + @patch.dict( + os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} + ) + def test_get_via_https_proxy(self): + """test for federation request through a https proxy""" + self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"}, + ) + def test_get_via_https_proxy_with_auth(self): + """test for federation request through a https proxy with authentication""" + self._do_get_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"user:pass" + ) + + def _do_get_via_proxy( + self, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, + ): + """Send a https federation request via an agent and check that it is correctly + received at the proxy and client. The proxy can use either http or https. + Args: + expect_proxy_ssl: True if we expect the request to connect to the proxy via https. + expected_auth_credentials: credentials we expect to be presented to authenticate at the proxy + """ + self.agent = self._make_agent() + + self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["proxy.com"] = "9.9.9.9" + test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + # make sure we are connecting to the proxy + self.assertEqual(host, "9.9.9.9") + self.assertEqual(port, 1080) + + # make a test server to act as the proxy, and wire up the client + proxy_server = self._make_connection( + client_factory, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, + ) + + assert isinstance(proxy_server, HTTPChannel) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"testserv:8448") + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if expected_auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(expected_auth_credentials) + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + request.finish() + + # now we make another test server to act as the upstream HTTP server. + server_ssl_protocol = _wrap_server_factory_for_tls( + _get_test_protocol_factory() + ).buildProtocol(None) + + # Tell the HTTP server to send outgoing traffic back via the proxy's transport. + proxy_server_transport = proxy_server.transport + server_ssl_protocol.makeConnection(proxy_server_transport) + + # ... and replace the protocol on the proxy's transport with the + # TLSMemoryBIOProtocol for the test server, so that incoming traffic + # to the proxy gets sent over to the HTTP(s) server. + + # See also comment at `_do_https_request_via_proxy` + # in ../test_proxyagent.py for more details + if expect_proxy_ssl: + assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) + proxy_server_transport.wrappedProtocol = server_ssl_protocol + else: + assert isinstance(proxy_server_transport, FakeTransport) + client_protocol = proxy_server_transport.other + c2s_transport = client_protocol.transport + c2s_transport.other = server_ssl_protocol + + self.reactor.advance(0) + + server_name = server_ssl_protocol._tlsConnection.get_servername() + expected_sni = b"testserv" + self.assertEqual( + server_name, + expected_sni, + f"Expected SNI {expected_sni!s} but got {server_name!s}", + ) + + # now there should be a pending request + http_server = server_ssl_protocol.wrappedProtocol + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) + # Check that the destination server DID NOT receive proxy credentials + self.assertIsNone(request.requestHeaders.getRawHeaders(b"Proxy-Authorization")) + content = request.content.read() + self.assertEqual(content, b"") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # send the headers + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") + + self.reactor.pump((0.1,)) + + response = self.successResultOf(test_d) + + # that should give us a Response object + self.assertEqual(response.code, 200) + + # Send the body + request.write('{ "a": 1 }'.encode("ascii")) + request.finish() + + self.reactor.pump((0.1,)) + + # check it can be read + json = self.successResultOf(treq.json_content(response)) + self.assertEqual(json, {"a": 1}) + def test_get_ip_address(self): """ Test the behaviour when the server name contains an explicit IP (with no port) """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.4"] = "1.2.3.4" @@ -320,6 +531,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IPv6 address (with no port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -355,6 +567,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -389,6 +602,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" @@ -441,6 +656,8 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.5"] = "1.2.3.5" @@ -471,6 +688,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the server name has no port, no SRV, and no well-known """ + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -524,6 +742,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_get_well_known(self): """Test the behaviour when the .well-known delegates elsewhere""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -587,6 +806,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -675,6 +896,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -743,6 +965,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. + ip_whitelist=IPSet(), ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( @@ -780,6 +1003,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"srvtarget", port=8443)] ) @@ -820,6 +1045,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known redirects to a place where there is a SRV. """ + self.agent = self._make_agent() + self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "5.6.7.8" @@ -876,6 +1103,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) @@ -937,6 +1165,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com @@ -1140,6 +1369,8 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails.""" + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [ Server(host=b"target.com", port=8443), @@ -1266,34 +1497,49 @@ def _check_logcontext(context): raise AssertionError("Expected logcontext %s but was %s" % (context, current)) -def _build_test_server(connection_creator): - """Construct a test server - - This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol - +def _wrap_server_factory_for_tls( + factory: IProtocolFactory, sanlist: Iterable[bytes] = None +) -> IProtocolFactory: + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + factory: protocol factory to wrap + sanlist: list of domains the cert should be valid for + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + +def _get_test_protocol_factory() -> IProtocolFactory: + """Get a protocol Factory which will build an HTTPChannel Returns: - TLSMemoryBIOProtocol + interfaces.IProtocolFactory """ server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. server_factory.log = _log_request - server_tls_factory = TLSMemoryBIOFactory( - connection_creator, isClient=False, wrappedFactory=server_factory - ) - - return server_tls_factory.buildProtocol(None) + return server_factory -def _log_request(request): +def _log_request(request: str): """Implements Factory.log, which is expected by Request.finish""" - logger.info("Completed request %s", request) + logger.info(f"Completed request {request}") @implementer(IPolicyForHTTPS) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index e5865c161d..2db77c6a73 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -29,7 +29,8 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from synapse.http.client import BlacklistingReactorWrapper -from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy +from synapse.http.connectproxyclient import ProxyCredentials +from synapse.http.proxyagent import ProxyAgent, parse_proxy from tests.http import TestServerTLSConnectionFactory, get_test_https_policy from tests.server import FakeTransport, ThreadedMemoryReactorClock @@ -392,7 +393,9 @@ class MatrixFederationAgentTests(TestCase): """ Tests that requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -402,13 +405,17 @@ class MatrixFederationAgentTests(TestCase): """ Tests that authenticated requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} ) def test_http_request_via_https_proxy(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -418,12 +425,16 @@ class MatrixFederationAgentTests(TestCase): }, ) def test_http_request_via_https_proxy_with_auth(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=None) + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -431,14 +442,18 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) def test_https_request_via_https_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=None) + self._do_https_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -446,20 +461,22 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_https_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) def _do_http_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ - if ssl: + if expect_proxy_ssl: agent = ProxyAgent( self.reactor, use_proxy=True, contextFactory=get_test_https_policy() ) @@ -480,9 +497,9 @@ class MatrixFederationAgentTests(TestCase): http_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) # the FakeTransport is async, so we need to pump the reactor @@ -498,9 +515,9 @@ class MatrixFederationAgentTests(TestCase): b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -523,14 +540,14 @@ class MatrixFederationAgentTests(TestCase): def _do_https_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ agent = ProxyAgent( self.reactor, @@ -552,9 +569,9 @@ class MatrixFederationAgentTests(TestCase): proxy_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) assert isinstance(proxy_server, HTTPChannel) @@ -570,9 +587,9 @@ class MatrixFederationAgentTests(TestCase): b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -606,7 +623,7 @@ class MatrixFederationAgentTests(TestCase): # Protocol to implement the proxy, which starts out by forwarding to an # HTTPChannel (to implement the CONNECT command) and can then be switched # into a mode where it forwards its traffic to another Protocol.) - if ssl: + if expect_proxy_ssl: assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) proxy_server_transport.wrappedProtocol = server_ssl_protocol else: diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 81d9e2f484..7dd519cd44 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -20,7 +20,7 @@ from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, presence, room +from synapse.rest.client import login, presence, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence @@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e04bc5c9a6..fa8018e5a7 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import email.message import os +from typing import Dict, List, Sequence, Tuple import attr import pkg_resources @@ -21,7 +22,7 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.api.errors import Codes, SynapseError -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase @@ -45,14 +46,6 @@ class EmailPusherTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): - # List[Tuple[Deferred, args, kwargs]] - self.email_attempts = [] - - def sendmail(*args, **kwargs): - d = Deferred() - self.email_attempts.append((d, args, kwargs)) - return d - config = self.default_config() config["email"] = { "enable_notifs": True, @@ -75,7 +68,18 @@ class EmailPusherTests(HomeserverTestCase): config["public_baseurl"] = "aaa" config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = [] + + def sendmail(*args, **kwargs): + # This mocks out synapse.reactor.send_email._sendmail. + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + hs.get_send_email_handler()._sendmail = sendmail return hs @@ -123,6 +127,8 @@ class EmailPusherTests(HomeserverTestCase): ) ) + self.auth_handler = hs.get_auth_handler() + def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated their email. @@ -251,6 +257,39 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about those messages self._check_for_mail() + def test_room_notifications_include_avatar(self): + # Create a room and set its avatar. + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.send_state( + room, "m.room.avatar", {"url": "mxc://DUMMY_MEDIA_ID"}, self.access_token + ) + + # Invite two other uses. + for other in self.others: + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=other.id + ) + self.helper.join(room=room, user=other.id, tok=other.token) + + # The other users send some messages. + # TODO It seems that two messages are required to trigger an email? + self.helper.send(room, body="Alpha", tok=self.others[0].token) + self.helper.send(room, body="Beta", tok=self.others[1].token) + + # We should get emailed about those messages + args, kwargs = self._check_for_mail() + + # That email should contain the room's avatar + msg: bytes = args[5] + # Multipart: plain text, base 64 encoded; html, base 64 encoded + html = ( + email.message_from_bytes(msg) + .get_payload()[1] + .get_payload(decode=True) + .decode() + ) + self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html) + def test_empty_room(self): """All users leaving a room shouldn't cause the pusher to break.""" # Create a simple room with two users @@ -303,9 +342,95 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() - def _check_for_mail(self): - """Check that the user receives an email notification""" + def test_no_email_sent_after_removed(self): + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, + src=self.user_id, + tok=self.access_token, + targ=self.others[0].id, + ) + self.helper.join( + room=room, + user=self.others[0].id, + tok=self.others[0].token, + ) + + # The other user sends a single message. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + + # We should get emailed about that message + self._check_for_mail() + # disassociate the user's email address + self.get_success( + self.auth_handler.delete_threepid( + user_id=self.user_id, + medium="email", + address="a@example.com", + ) + ) + + # check that the pusher for that email address has been deleted + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 0) + + def test_remove_unlinked_pushers_background_job(self): + """Checks that all existing pushers associated with unlinked email addresses are removed + upon running the remove_deleted_email_pushers background update. + """ + # disassociate the user's email address manually (without deleting the pusher). + # This resembles the old behaviour, which the background update below is intended + # to clean up. + self.get_success( + self.hs.get_datastore().user_delete_threepid( + self.user_id, "email", "a@example.com" + ) + ) + + # Run the "remove_deleted_email_pushers" background job + self.get_success( + self.hs.get_datastore().db_pool.simple_insert( + table="background_updates", + values={ + "update_name": "remove_deleted_email_pushers", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.hs.get_datastore().db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success( + self.hs.get_datastore().db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.hs.get_datastore().db_pool.updates.do_next_background_update(100), + by=0.1, + ) + + # Check that all pushers with unlinked addresses were deleted + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 0) + + def _check_for_mail(self) -> Tuple[Sequence, Dict]: + """ + Assert that synapse sent off exactly one email notification. + + Returns: + args and kwargs passed to synapse.reactor.send_email._sendmail for + that notification. + """ # Get the stream ordering before it gets sent pushers = self.get_success( self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) @@ -328,8 +453,9 @@ class EmailPusherTests(HomeserverTestCase): # One email was attempted to be sent self.assertEqual(len(self.email_attempts), 1) + deferred, sendmail_args, sendmail_kwargs = self.email_attempts[0] # Make the email succeed - self.email_attempts[0][0].callback(True) + deferred.callback(True) self.pump() # One email was attempted to be sent @@ -345,3 +471,4 @@ class EmailPusherTests(HomeserverTestCase): # Reset the attempts. self.email_attempts = [] + return sendmail_args, sendmail_kwargs diff --git a/tests/push/test_http.py b/tests/push/test_http.py index ffd75b1491..c068d329a9 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -18,8 +18,7 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import receipts +from synapse.rest.client import login, receipts, room from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index db80a0bdbd..b25a06b427 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.roommember import RoomsForUser +from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition from tests.server import FakeTransport @@ -150,6 +150,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "invite", event.event_id, event.internal_metadata.stream_ordering, + RoomVersions.V1.identifier, ) ], ) @@ -216,7 +217,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_rooms_for_user_with_stream_ordering", (USER_ID_2,), - {(ROOM_ID, expected_pos)}, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, ) def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): @@ -305,7 +306,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): expected_pos = PersistedEventPosition( "master", j2.internal_metadata.stream_ordering ) - self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)}) + self.assertEqual( + joined_rooms, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, + ) event_id = 0 diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 666008425a..f198a94887 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -24,7 +24,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamRow, ) from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseStreamTestCase from tests.test_utils.event_injection import inject_event, inject_member_event diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 1346e0e160..43a16bb141 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, make_request diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index b9751efdc5..995097d72c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index a0c710f855..92a5b53e11 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -205,7 +205,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) store = self.hs.get_datastore() - federation = self.hs.get_federation_handler() + federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) room_version = self.get_success(store.get_room_version(room)) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index ffa425328f..ac419f0db3 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -22,7 +22,7 @@ from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.server import HomeServer from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 1e4e3821b9..4094a75f36 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.internet import defer from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index f3615af97e..0a6e4795ee 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -16,8 +16,7 @@ from unittest.mock import patch from synapse.api.room_versions import RoomVersion from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index a7c6e595b9..bfa638fb4b 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -24,8 +24,7 @@ import synapse.rest.admin from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import groups +from synapse.rest.client import groups, login, room from tests import unittest from tests.server import FakeSite, make_request diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 120730b764..a3679be205 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import urllib.parse +from parameterized import parameterized + import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest @@ -45,49 +46,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.other_user_device_id, ) - def test_no_auth(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get a device of an user without authentication. """ - channel = self.make_request("GET", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("PUT", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("DELETE", self.url, b"{}") + channel = self.make_request(method, self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ channel = self.make_request( - "GET", - self.url, - access_token=self.other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "PUT", - self.url, - access_token=self.other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, self.url, access_token=self.other_user_token, ) @@ -95,7 +70,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_user_does_not_exist(self, method: str): """ Tests that a lookup for a user that does not exist returns a 404 """ @@ -105,7 +81,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) @@ -113,25 +89,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_user_is_not_local(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_user_is_not_local(self, method: str): """ Tests that a lookup for a user that is not a local returns a 400 """ @@ -141,25 +100,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) - - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) @@ -219,12 +160,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) } - body = json.dumps(update) channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=update, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -275,12 +215,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): Tests a normal successful update of display name """ # Set new display_name - body = json.dumps({"display_name": "new displayname"}) channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"display_name": "new displayname"}, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -529,12 +468,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ Tests that a remove of a device that does not exist returns 200. """ - body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"devices": ["unknown_device1", "unknown_device2"]}, ) # Delete unknown devices returns status 200 @@ -560,12 +498,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): device_ids.append(str(d["device_id"])) # Delete devices - body = json.dumps({"devices": device_ids}) channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"devices": device_ids}, ) self.assertEqual(200, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index f15d1cf6f7..e9ef89731f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -16,8 +16,7 @@ import json import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 7198fd293f..972d60570c 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -20,7 +20,7 @@ from parameterized import parameterized import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py new file mode 100644 index 0000000000..4927321e5a --- /dev/null +++ b/tests/rest/admin/test_registration_tokens.py @@ -0,0 +1,710 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login + +from tests import unittest + + +class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/registration_tokens" + + def _new_token(self, **kwargs): + """Helper function to create a token.""" + token = kwargs.get( + "token", + "".join(random.choices(string.ascii_letters, k=8)), + ) + self.get_success( + self.store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": kwargs.get("uses_allowed", None), + "pending": kwargs.get("pending", 0), + "completed": kwargs.get("completed", 0), + "expiry_time": kwargs.get("expiry_time", None), + }, + ) + ) + return token + + # CREATION + + def test_create_no_auth(self): + """Try to create a token without authentication.""" + channel = self.make_request("POST", self.url + "/new", {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_create_requester_not_admin(self): + """Try to create a token while not an admin.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_create_using_defaults(self): + """Create a token using all the defaults.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_specifying_fields(self): + """Create a token specifying the value of all fields.""" + data = { + "token": "abcd", + "uses_allowed": 1, + "expiry_time": self.clock.time_msec() + 1000000, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], "abcd") + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_with_null_value(self): + """Create a token specifying unlimited uses and no expiry.""" + data = { + "uses_allowed": None, + "expiry_time": None, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_token_too_long(self): + """Check token longer than 64 chars is invalid.""" + data = {"token": "a" * 65} + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_invalid_chars(self): + """Check you can't create token with invalid characters.""" + data = { + "token": "abc/def", + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_already_exists(self): + """Check you can't create token that already exists.""" + data = { + "token": "abcd", + } + + channel1 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + + channel2 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_unable_to_generate_token(self): + """Check right error is raised when server can't generate unique token.""" + # Create all possible single character tokens + tokens = [] + for c in string.ascii_letters + string.digits + "-_": + tokens.append( + { + "token": c, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + } + ) + self.get_success( + self.store.db_pool.simple_insert_many( + "registration_tokens", + tokens, + "create_all_registration_tokens", + ) + ) + + # Check creating a single character token fails with a 500 status code + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + + def test_create_uses_allowed(self): + """Check you can only create a token with good values for uses_allowed.""" + # Should work with 0 (token is invalid from the start) + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + + # Should fail with negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_expiry_time(self): + """Check you can't create a token with an invalid expiry_time.""" + # Should fail with a time in the past + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() - 10000}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() + 1000000.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_length(self): + """Check you can only generate a token with a valid length.""" + # Should work with 64 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 64}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 64) + + # Should fail with 0 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"length": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a float + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 8.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with 65 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 65}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # UPDATING + + def test_update_no_auth(self): + """Try to update a token without authentication.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_update_requester_not_admin(self): + """Try to update a token while not an admin.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_update_non_existent(self): + """Try to update a token that doesn't exist.""" + channel = self.make_request( + "PUT", + self.url + "/1234", + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_update_uses_allowed(self): + """Test updating just uses_allowed.""" + # Create new token using default values + token = self._new_token() + + # Should succeed with 1 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with 0 (makes token invalid) + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should fail with a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_expiry_time(self): + """Test updating just expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + # Should succeed with a time in the future + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should fail with a time in the past + past_time = self.clock.time_msec() - 10000 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": past_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time + 0.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_both(self): + """Test updating both uses_allowed and expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + data = { + "uses_allowed": 1, + "expiry_time": new_expiry_time, + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + + def test_update_invalid_type(self): + """Test using invalid types doesn't work.""" + # Create new token using default values + token = self._new_token() + + data = { + "uses_allowed": False, + "expiry_time": "1626430124000", + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # DELETING + + def test_delete_no_auth(self): + """Try to delete a token without authentication.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_delete_requester_not_admin(self): + """Try to delete a token while not an admin.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_non_existent(self): + """Try to delete a token that doesn't exist.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_delete(self): + """Test deleting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "DELETE", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # GETTING ONE + + def test_get_no_auth(self): + """Try to get a token without authentication.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_get_requester_not_admin(self): + """Try to get a token while not an admin.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_get_non_existent(self): + """Try to get a token that doesn't exist.""" + channel = self.make_request( + "GET", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_get(self): + """Test getting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], token) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + # LISTING + + def test_list_no_auth(self): + """Try to list tokens without authentication.""" + channel = self.make_request("GET", self.url, {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_list_requester_not_admin(self): + """Try to list tokens while not an admin.""" + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_list_all(self): + """Test listing all tokens.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 1) + token_info = channel.json_body["registration_tokens"][0] + self.assertEqual(token_info["token"], token) + self.assertIsNone(token_info["uses_allowed"]) + self.assertIsNone(token_info["expiry_time"]) + self.assertEqual(token_info["pending"], 0) + self.assertEqual(token_info["completed"], 0) + + def test_list_invalid_query_parameter(self): + """Test with `valid` query parameter not `true` or `false`.""" + channel = self.make_request( + "GET", + self.url + "?valid=x", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + def _test_list_query_parameter(self, valid: str): + """Helper used to test both valid=true and valid=false.""" + # Create 2 valid and 2 invalid tokens. + now = self.hs.get_clock().time_msec() + # Create always valid token + valid1 = self._new_token() + # Create token that hasn't been used up + valid2 = self._new_token(uses_allowed=1) + # Create token that has expired + invalid1 = self._new_token(expiry_time=now - 10000) + # Create token that has been used up but hasn't expired + invalid2 = self._new_token( + uses_allowed=2, + pending=1, + completed=1, + expiry_time=now + 1000000, + ) + + if valid == "true": + tokens = [valid1, valid2] + else: + tokens = [invalid1, invalid2] + + channel = self.make_request( + "GET", + self.url + "?valid=" + valid, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 2) + token_info_1 = channel.json_body["registration_tokens"][0] + token_info_2 = channel.json_body["registration_tokens"][1] + self.assertIn(token_info_1["token"], tokens) + self.assertIn(token_info_2["token"], tokens) + + def test_list_valid(self): + """Test listing just valid tokens.""" + self._test_list_query_parameter(valid="true") + + def test_list_invalid(self): + """Test listing just invalid tokens.""" + self._test_list_query_parameter(valid="false") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 17ec8bfd3b..40e032df7f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -22,130 +22,13 @@ from parameterized import parameterized_class import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes -from synapse.rest.client.v1 import directory, events, login, room +from synapse.rest.client import directory, events, login, room from tests import unittest """Tests admin REST events for /rooms paths.""" -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/shutdown_room/" + room_id - channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/shutdown_room/" + room_id - channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room.""" - - url = "rooms/%s/initialSync" % (room_id,) - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - @parameterized_class( ("method", "url_template"), [ @@ -557,51 +440,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API.""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. - for table in PURGE_TABLES: - count = self.get_success( - self.store.db_pool.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg=f"Rows not purged in {table}") - - class RoomTestCase(unittest.HomeserverTestCase): """Test /room admin API.""" diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py new file mode 100644 index 0000000000..fbceba3254 --- /dev/null +++ b/tests/rest/admin/test_server_notice.py @@ -0,0 +1,450 @@ +# Copyright 2021 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login, room, sync +from synapse.storage.roommember import RoomsForUser +from synapse.types import JsonDict + +from tests import unittest +from tests.unittest import override_config + + +class ServerNoticeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() + self.server_notices_manager = self.hs.get_server_notices_manager() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/send_server_notice" + + def test_no_auth(self): + """Try to send a server notice without authentication.""" + channel = self.make_request("POST", self.url) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """If the user is not a server admin, an error is returned.""" + channel = self.make_request( + "POST", + self.url, + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_does_not_exist(self): + """Tests that a lookup for a user that does not exist returns a 404""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": "@unknown_person:test", "content": ""}, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": "@unknown_person:unknown_domain", + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Server notices can only be sent to local users", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_invalid_parameter(self): + """If parameters are invalid, an error is returned.""" + + # no content, no user + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) + + # no content + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # no body + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user, "content": ""}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'body' not in content", channel.json_body["error"]) + + # no msgtype + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user, "content": {"body": ""}}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'msgtype' not in content", channel.json_body["error"]) + + def test_server_notice_disabled(self): + """Tests that server returns error if server notice is disabled""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Server notices are not enabled on this server", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice(self): + """ + Tests that sending two server notices is successfully, + the server uses the same room and do not send messages twice. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # invalidate cache of server notices room_ids + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has no new invites or memberships + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + + self.assertEqual(len(messages), 2) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + self.assertEqual(messages[1]["content"]["body"], "test msg two") + self.assertEqual(messages[1]["sender"], "@notices:test") + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_leave_room(self): + """ + Tests that sending a server notices is successfully. + The user leaves the room and the second message appears + in a new room. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # user leaves the romm + self.helper.leave( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id the user gets the message + # in old room + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # room has the same id + self.assertNotEqual(first_room_id, second_room_id) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_delete_room(self): + """ + Tests that the user get server notice in a new room + after the first server notice room was deleted. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # shut down and purge room + self.get_success( + self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user) + ) + self.get_success(self.pagination_handler.purge_room(first_room_id)) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(first_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id it gives an error + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get message + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # second room has new ID + self.assertNotEqual(first_room_id, second_room_id) + + def _check_invite_and_join_status( + self, user_id: str, expected_invites: int, expected_memberships: int + ) -> RoomsForUser: + """Check invite and room membership status of a user. + + Args + user_id: user to check + expected_invites: number of expected invites of this user + expected_memberships: number of expected room memberships of this user + Returns + room_ids from the rooms that the user is invited + """ + + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(user_id) + ) + self.assertEqual(expected_invites, len(invited_rooms)) + + room_ids = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(expected_memberships, len(room_ids)) + + return invited_rooms + + def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]: + """ + Do a sync and get messages of a room. + + Args + room_id: room that contains the messages + token: access token of user + + Returns + list of messages contained in the room + """ + channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=token + ) + self.assertEqual(channel.code, 200) + + # Get the messages + room = channel.json_body["rooms"]["join"][room_id] + messages = [ + x for x in room["timeline"]["events"] if x["type"] == "m.room.message" + ] + return messages diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 79cac4266b..5cd82209c4 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 42f50c0921..ee204c404b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -15,17 +15,20 @@ import hashlib import hmac import json +import os import urllib.parse from binascii import unhexlify from typing import List, Optional from unittest.mock import Mock, patch +from parameterized import parameterized + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions -from synapse.rest.client.v1 import login, logout, profile, room -from synapse.rest.client.v2_alpha import devices, sync +from synapse.rest.client import devices, login, logout, profile, room, sync +from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.types import JsonDict, UserID from tests import unittest @@ -72,7 +75,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -104,7 +107,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({"nonce": nonce}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -112,7 +115,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): @@ -166,7 +169,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self): @@ -191,13 +194,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): @@ -219,7 +222,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -230,28 +233,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({"nonce": nonce()}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -262,28 +265,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({"nonce": nonce(), "username": "a"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -301,7 +304,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -322,11 +325,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -348,11 +351,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -374,7 +377,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") @@ -399,11 +402,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -449,7 +452,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -638,7 +641,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -1085,7 +1088,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1236,56 +1239,114 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_get_user(self): + def test_invalid_parameter(self): """ - Test a simple get of a user. + If parameters are invalid, an error is returned. """ + + # admin not bool channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"admin": "not_bool"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual("User", channel.json_body["displayname"]) - self._check_fields(channel.json_body) + # deactivated not bool + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": "not_bool"}, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def test_get_user_with_sso(self): - """ - Test get a user with SSO details. - """ - self.get_success( - self.store.record_user_external_id( - "auth_provider1", "external_id1", self.other_user - ) + # password not str + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": True}, ) - self.get_success( - self.store.record_user_external_id( - "auth_provider2", "external_id2", self.other_user - ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # password not length + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": "x" * 513}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # user_type not valid channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"user_type": "new type"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual( - "external_id1", channel.json_body["external_ids"][0]["external_id"] + # external_ids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} + }, ) - self.assertEqual( - "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual( - "external_id2", channel.json_body["external_ids"][1]["external_id"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # threepids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual( - "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"address": "value"}}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_get_user(self): + """ + Test a simple get of a user. + """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) def test_create_server_admin(self): @@ -1349,6 +1410,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + ], "avatar_url": "mxc://fibble/wibble", } @@ -1364,6 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(1, len(channel.json_body["threepids"])) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) + self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1603,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test setting threepid for an other user. """ - # Delete old and add new threepid to user + # Add two threepids to user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + {"medium": "email", "address": "bob2@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) + # result does not always have the same sort order, therefore it becomes sorted + sorted_result = sorted( + channel.json_body["threepids"], key=lambda k: k["address"] + ) + self.assertEqual("email", sorted_result[0]["medium"]) + self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) + self.assertEqual("email", sorted_result[1]["medium"]) + self.assertEqual("bob2@bob.bob", sorted_result[1]["address"]) + self._check_fields(channel.json_body) + + # Set a new and remove a threepid channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, + content={ + "threepids": [ + {"medium": "email", "address": "bob2@bob.bob"}, + {"medium": "email", "address": "bob3@bob.bob"}, + ], + }, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1625,8 +1735,122 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) + + # Remove threepids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self._check_fields(channel.json_body) + + def test_set_external_id(self): + """ + Test setting external id for an other user. + """ + + # Add two external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + # result does not always have the same sort order, therefore it becomes sorted + self.assertEqual( + sorted(channel.json_body["external_ids"], key=lambda k: k["auth_provider"]), + [ + {"auth_provider": "auth_provider1", "external_id": "external_id1"}, + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + ], + ) + self._check_fields(channel.json_body) + + # Set a new and remove an external_id + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + { + "external_id": "external_id3", + "auth_provider": "auth_provider3", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Remove external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["external_ids"])) def test_deactivate_user(self): """ @@ -2180,7 +2404,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self): """ @@ -2249,6 +2473,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() + self.filepaths = MediaFilePaths(hs.config.media_store_path) self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2258,37 +2483,34 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): - """ - Try to list media of an user without authentication. - """ - channel = self.make_request("GET", self.url, b"{}") + @parameterized.expand(["GET", "DELETE"]) + def test_no_auth(self, method: str): + """Try to list media of an user without authentication.""" + channel = self.make_request(method, self.url, {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error is returned. - """ + @parameterized.expand(["GET", "DELETE"]) + def test_requester_is_no_admin(self, method: str): + """If the user is not a server admin, an error is returned.""" other_user_token = self.login("user", "pass") channel = self.make_request( - "GET", + method, self.url, access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): - """ - Tests that a lookup for a user that does not exist returns a 404 - """ + @parameterized.expand(["GET", "DELETE"]) + def test_user_does_not_exist(self, method: str): + """Tests that a lookup for a user that does not exist returns a 404""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) @@ -2296,25 +2518,22 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): - """ - Tests that a lookup for a user that is not a local returns a 400 - """ + @parameterized.expand(["GET", "DELETE"]) + def test_user_is_not_local(self, method: str): + """Tests that a lookup for a user that is not a local returns a 400""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) - def test_limit(self): - """ - Testing list of media with limit - """ + def test_limit_GET(self): + """Testing list of media with limit""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2326,16 +2545,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["media"]) - def test_from(self): - """ - Testing list of media with a defined starting point (from) - """ + def test_limit_DELETE(self): + """Testing delete of media with limit""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 5) + self.assertEqual(len(channel.json_body["deleted_media"]), 5) + + def test_from_GET(self): + """Testing list of media with a defined starting point (from)""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2347,16 +2581,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) - def test_limit_and_from(self): - """ - Testing list of media with a defined starting point and limit - """ + def test_from_DELETE(self): + """Testing delete of media with a defined starting point (from)""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 15) + self.assertEqual(len(channel.json_body["deleted_media"]), 15) + + def test_limit_and_from_GET(self): + """Testing list of media with a defined starting point and limit""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2368,59 +2617,78 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) self._check_fields(channel.json_body["media"]) - def test_invalid_parameter(self): - """ - If parameters are invalid, an error is returned. - """ + def test_limit_and_from_DELETE(self): + """Testing delete of media with a defined starting point and limit""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["deleted_media"]), 10) + + @parameterized.expand(["GET", "DELETE"]) + def test_invalid_parameter(self, method: str): + """If parameters are invalid, an error is returned.""" # unkown order_by channel = self.make_request( - "GET", + method, self.url + "?order_by=bar", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order channel = self.make_request( - "GET", + method, self.url + "?dir=bar", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative limit channel = self.make_request( - "GET", + method, self.url + "?limit=-5", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from channel = self.make_request( - "GET", + method, self.url + "?from=-5", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): """ Testing that `next_token` appears at the right place + + For deleting media `next_token` is not useful, because + after deleting media the media has a new order. """ number_media = 20 @@ -2435,7 +2703,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2448,7 +2716,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2461,7 +2729,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -2475,12 +2743,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_user_has_no_media(self): + def test_user_has_no_media_GET(self): """ Tests that a normal lookup for media is successfully if user has no media created @@ -2496,11 +2764,24 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) - def test_get_media(self): + def test_user_has_no_media_DELETE(self): """ - Tests that a normal lookup for media is successfully + Tests that a delete is successful if user has no media """ + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["deleted_media"])) + + def test_get_media(self): + """Tests that a normal lookup for media is successful""" + number_media = 5 other_user_tok = self.login("user", "pass") self._create_media_for_user(other_user_tok, number_media) @@ -2517,6 +2798,35 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) + def test_delete_media(self): + """Tests that a normal delete of media is successful""" + + number_media = 5 + other_user_tok = self.login("user", "pass") + media_ids = self._create_media_for_user(other_user_tok, number_media) + + # Test if the file exists + local_paths = [] + for media_id in media_ids: + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + local_paths.append(local_path) + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_media, channel.json_body["total"]) + self.assertEqual(number_media, len(channel.json_body["deleted_media"])) + self.assertCountEqual(channel.json_body["deleted_media"], media_ids) + + # Test if the file is deleted + for local_path in local_paths: + self.assertFalse(os.path.exists(local_path)) + def test_order_by(self): """ Testing order list with parameter `order_by` @@ -2622,13 +2932,16 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): [media2] + sorted([media1, media3]), "safe_from_quarantine", "b" ) - def _create_media_for_user(self, user_token: str, number_media: int): + def _create_media_for_user(self, user_token: str, number_media: int) -> List[str]: """ Create a number of media for a specific user Args: user_token: Access token of the user number_media: Number of media to be created for the user + Returns: + List of created media ID """ + media_ids = [] for _ in range(number_media): # file size is 67 Byte image_data = unhexlify( @@ -2637,7 +2950,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): b"0a2db40000000049454e44ae426082" ) - self._create_media_and_access(user_token, image_data) + media_ids.append(self._create_media_and_access(user_token, image_data)) + + return media_ids def _create_media_and_access( self, @@ -2680,7 +2995,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): 200, channel.code, msg=( - "Expected to receive a 200 on accessing media: %s" % server_and_media_id + f"Expected to receive a 200 on accessing media: {server_and_media_id}" ), ) @@ -2718,12 +3033,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): url = self.url + "?" if order_by is not None: - url += "order_by=%s&" % (order_by,) + url += f"order_by={order_by}&" if dir is not None and dir in ("b", "f"): - url += "dir=%s" % (dir,) + url += f"dir={dir}" channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -2762,7 +3077,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): @@ -2803,7 +3118,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -2815,11 +3130,11 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2829,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling `/logout/all` does *not* expire @@ -2840,17 +3155,17 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( @@ -2867,13 +3182,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2883,7 +3198,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3243,7 +3558,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) channel = self.make_request( "POST", @@ -3279,7 +3594,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3290,7 +3605,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3301,7 +3616,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3312,7 +3627,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3337,7 +3652,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3351,7 +3666,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3362,7 +3677,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3373,7 +3688,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3383,7 +3698,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3393,7 +3708,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3403,6 +3718,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py new file mode 100644 index 0000000000..4e1c49c28b --- /dev/null +++ b/tests/rest/admin/test_username_available.py @@ -0,0 +1,62 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse.rest.admin +from synapse.api.errors import Codes, SynapseError +from synapse.rest.client import login + +from tests import unittest + + +class UsernameAvailableTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + url = "/_synapse/admin/v1/username_available" + + def prepare(self, reactor, clock, hs): + self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + async def check_username(username): + if username == "allowed": + return True + raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + handler = self.hs.get_registration_handler() + handler.check_username = check_username + + def test_username_available(self): + """ + The endpoint should return a 200 response if the username does not exist + """ + + url = "%s?username=%s" % (self.url, "allowed") + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertTrue(channel.json_body["available"]) + + def test_username_unavailable(self): + """ + The endpoint should return a 200 response if the username does not exist + """ + + url = "%s?username=%s" % (self.url, "disallowed") + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") + self.assertEqual(channel.json_body["error"], "User ID already taken.") diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/test_account.py index 317a2287e3..b946fca8b3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/test_account.py @@ -25,8 +25,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import account, register +from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest @@ -47,12 +46,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -67,7 +60,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + return hs def prepare(self, reactor, clock, hs): @@ -511,11 +513,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -530,7 +527,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + return self.hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/test_auth.py index 6b90f838b6..e2fcbdc63a 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -19,14 +19,13 @@ from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, auth, devices, register +from synapse.rest.client import account, auth, devices, login, register from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID from tests import unittest from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.v1.utils import TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.unittest import override_config, skip_unless diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/test_capabilities.py index f80f48a455..422361b62a 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -13,8 +13,7 @@ # limitations under the License. import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import capabilities +from synapse.rest.client import capabilities, login from tests import unittest from tests.unittest import override_config @@ -31,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.url = b"/_matrix/client/r0/capabilities" hs = self.setup_test_homeserver() - self.store = hs.get_datastore() self.config = hs.config self.auth_handler = hs.get_auth_handler() return hs + def prepare(self, reactor, clock, hs): + self.localpart = "user" + self.password = "pass" + self.user = self.register_user(self.localpart, self.password) + def test_check_auth_required(self): channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) def test_get_room_version_capabilities(self): - self.register_user("user", "pass") - access_token = self.login("user", "pass") + access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] @@ -58,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): ) def test_get_change_password_capabilities_password_login(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) - access_token = self.login(user, password) + access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] @@ -71,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"localdb_enabled": False}}) def test_get_change_password_capabilities_localdb_disabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -88,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"enabled": False}}) def test_get_change_password_capabilities_password_disabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -103,13 +96,86 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_does_not_include_msc3244_fields_by_default(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) + def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): + """Test that per default msc3283 is disabled server returns `m.change_password`.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) + self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) + self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) + + @override_config({"experimental_features": {"msc3283_enabled": True}}) + def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): + """Test if msc3283 is enabled server returns capabilities.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + @override_config( + { + "enable_set_displayname": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_displayname_capabilities_displayname_disabled(self): + """Test if set displayname is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + + @override_config( + { + "enable_set_avatar_url": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + """Test if set avatar_url is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + + @override_config( + { + "enable_3pid_changes": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_change_3pid_capabilities_3pid_disabled(self): + """Test if change 3pid is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + @override_config({"experimental_features": {"msc3244_enabled": False}}) + def test_get_does_not_include_msc3244_fields_when_disabled(self): access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -121,14 +187,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - @override_config({"experimental_features": {"msc3244_enabled": True}}) def test_get_does_include_msc3244_fields_when_enabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 5cc62a910a..65c58ce70a 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -16,7 +16,7 @@ import os import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.rest.consent import consent_resource from tests import unittest diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/test_directory.py index 8ed470490b..d2181ea907 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -15,7 +15,7 @@ import json from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias from synapse.util.stringutils import random_string diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index eec0fc01f9..3d7aa8ec86 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests import unittest diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/test_events.py index 2789d51546..a90294003e 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/test_events.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import synapse.rest.admin -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client import events, login, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/test_filter.py index c7e47725b7..475c6bed3d 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -15,7 +15,7 @@ from twisted.internet import defer from synapse.api.errors import Codes -from synapse.rest.client.v2_alpha import filter +from synapse.rest.client import filter from tests import unittest diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py new file mode 100644 index 0000000000..ad0425ae65 --- /dev/null +++ b/tests/rest/client/test_groups.py @@ -0,0 +1,56 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client import groups, room + +from tests import unittest +from tests.unittest import override_config + + +class GroupsTestCase(unittest.HomeserverTestCase): + user_id = "@alice:test" + room_creator_user_id = "@bob:test" + + servlets = [room.register_servlets, groups.register_servlets] + + @override_config({"enable_group_creation": True}) + def test_rooms_limited_by_visibility(self): + group_id = "+spqr:test" + + # Alice creates a group + channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {"group_id": group_id}) + + # Bob creates a private room + room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) + self.helper.auth_user_id = self.room_creator_user_id + self.helper.send_state( + room_id, "m.room.name", {"name": "bob's secret room"}, tok=None + ) + self.helper.auth_user_id = self.user_id + + # Alice adds the room to her group. + channel = self.make_request( + "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} + ) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {}) + + # Alice now tries to retrieve the room list of the space. + channel = self.make_request("GET", f"/groups/{group_id}/rooms") + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals( + channel.json_body, {"chunk": [], "total_room_count_estimate": 0} + ) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 478296ba0e..ca2e8ff8ef 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -15,7 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py new file mode 100644 index 0000000000..d7fa635eae --- /dev/null +++ b/tests/rest/client/test_keys.py @@ -0,0 +1,91 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from http import HTTPStatus + +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import keys, login + +from tests import unittest + + +class KeyQueryTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def test_rejects_device_id_ice_key_outside_of_list(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: "device_id1", + }, + }, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_rejects_device_key_given_as_map_to_bool(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: { + "device_id1": True, + }, + }, + }, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_requires_device_key(self): + """`device_keys` is required. We should complain if it's missing.""" + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + {}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/test_login.py index 7eba69642a..5b2243fe52 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/test_login.py @@ -24,16 +24,15 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import devices, register -from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.rest.client import devices, login, logout, register +from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/test_password_policy.py index 6f07ff6cbb..3cf5871899 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -17,8 +17,7 @@ import json from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, password_policy, register +from synapse.rest.client import account, login, password_policy, register from tests import unittest diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index ba5ad47df5..c0de4c93a8 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.errors import Codes +from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase @@ -204,3 +205,79 @@ class PowerLevelsTestCase(HomeserverTestCase): tok=self.admin_access_token, expect_code=200, # expect success ) + + def test_cannot_set_string_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL "0" + room_power_levels["users"].update({self.user_user_id: "0"}) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_large_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL above the max safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MAX_INT + 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_small_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL below the minimum safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MIN_INT - 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/test_presence.py index 597e4c67de..1d152352d1 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.internet import defer from synapse.handlers.presence import PresenceHandler -from synapse.rest.client.v1 import presence +from synapse.rest.client import presence from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/test_profile.py index 165ad33fb7..2860579c2e 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -14,7 +14,7 @@ """Tests REST events for /profile paths.""" from synapse.rest import admin -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from tests import unittest diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py index d077616082..d0ce91ccd9 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/test_push_rule_attrs.py @@ -13,7 +13,7 @@ # limitations under the License. import synapse from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, push_rule, room +from synapse.rest.client import login, push_rule, room from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index dfd85221d0..433d715f69 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/test_register.py index 1cad5f00eb..9f3ab2c985 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/test_register.py @@ -23,8 +23,8 @@ import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import account, account_validity, register, sync +from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.storage._base import db_to_json from tests import unittest from tests.unittest import override_config @@ -205,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"registration_requires_token": True}) + def test_POST_registration_requires_token(self): + username = "kermit" + device_id = "frogfone" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params = { + "username": username, + "password": "monkey", + "device_id": device_id, + } + + # Request without auth to get flows and session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + # Synapse adds a dummy stage to differentiate flows where otherwise one + # flow would be a subset of another flow. + self.assertCountEqual( + [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]], + (f["stages"] for f in flows), + ) + session = channel.json_body["session"] + + # Do the registration token stage and check it has completed + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + self.assertEquals(channel.result["code"], b"401", channel.result) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + # Do the m.login.dummy stage and check registration was successful + params["auth"] = { + "type": LoginType.DUMMY, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + det_data = { + "user_id": f"@{username}:{self.hs.hostname}", + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + # Check the `completed` counter has been incremented and pending is 0 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["completed"], 1) + self.assertEquals(res["pending"], 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_invalid(self): + params = { + "username": "kermit", + "password": "monkey", + } + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Test with token param missing (invalid) + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with non-string (invalid) + params["auth"]["token"] = 1234 + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with unknown token (invalid) + params["auth"]["token"] = "1234" + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_limit_uses(self): + token = "abcd" + store = self.hs.get_datastore() + # Create token that can be used once + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + # Do 2 requests without auth to get two session IDs + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with session1 and check `pending` is 1 + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + # Repeat request to make sure pending isn't increased again + self.make_request(b"POST", self.url, json.dumps(params1)) + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 1) + + # Check auth fails when using token with session2 + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + # Check pending=0 and completed=1 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["pending"], 0) + self.assertEquals(res["completed"], 1) + + # Check auth still fails when using token with session2 + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_expiry(self): + token = "abcd" + now = self.hs.get_clock().time_msec() + store = self.hs.get_datastore() + # Create token that expired yesterday + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": now - 24 * 60 * 60 * 1000, + }, + ) + ) + params = {"username": "kermit", "password": "monkey"} + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Check authentication fails with expired token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Update token so it expires tomorrow + self.get_success( + store.db_pool.simple_update_one( + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000}, + ) + ) + + # Check authentication succeeds + channel = self.make_request(b"POST", self.url, json.dumps(params)) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry(self): + """Test `pending` is decremented when an uncompleted session expires.""" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do 2 requests without auth to get two session IDs + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with both sessions + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + self.make_request(b"POST", self.url, json.dumps(params2)) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + + # Check `result` of registration token stage for session1 is `True` + result1 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session1, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertTrue(db_to_json(result1)) + + # Check `result` for session2 is the token used + result2 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session2, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertEquals(db_to_json(result2), token) + + # Delete both sessions (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + + # Check pending is now 0 + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry_deleted_token(self): + """Test session expiry doesn't break when the token is deleted. + + 1. Start but don't complete UIA with a registration token + 2. Delete the token from the database + 3. Expire the session + """ + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do request without auth to get a session ID + params = {"username": "kermit", "password": "monkey"} + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Use token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + self.make_request(b"POST", self.url, json.dumps(params)) + + # Delete token + self.get_success( + store.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + ) + ) + + # Delete session (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + def test_advertised_flows(self): channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) @@ -509,10 +874,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } # Email config. - self.email_attempts = [] - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) config["email"] = { "enable_notifs": True, @@ -532,7 +893,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "aaa" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastore() @@ -743,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) + + +class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): + servlets = [register.register_servlets] + url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity" + + def default_config(self): + config = super().default_config() + config["registration_requires_token"] = True + return config + + def test_GET_token_valid(self): + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], True) + + def test_GET_token_invalid(self): + token = "1234" + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], False) + + @override_config( + {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} + ) + def test_GET_ratelimiting(self): + token = "1234" + + for i in range(0, 6): + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/test_relations.py index 2e2f94742e..02b5e9a8d0 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -19,8 +19,7 @@ from typing import Optional from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import register, relations +from synapse.rest.client import login, register, relations, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/test_report_event.py index a76a6fef1e..ee6b0b9ebf 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -15,8 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e1a6e73e17..b58452195a 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -15,7 +15,7 @@ from unittest.mock import Mock from synapse.api.constants import EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.visibility import filter_events_for_client from tests import unittest diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/test_rooms.py index 3df070c936..50100a5ae4 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -19,15 +19,17 @@ import json from typing import Iterable -from unittest.mock import Mock +from unittest.mock import Mock, call from urllib import parse as urlparse +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account +from synapse.rest.client import account, directory, login, profile, room, sync from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string @@ -379,6 +381,8 @@ class RoomPermissionsTestCase(RoomBase): class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" + servlets = RoomBase.servlets + [sync.register_servlets] + user_id = "@sid1:red" def test_get_member_list(self): @@ -395,6 +399,86 @@ class RoomsMemberListTestCase(RoomBase): channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(403, channel.code, msg=channel.result["body"]) + def test_get_member_list_no_permission_with_at_token(self): + """ + Tests that a stranger to the room cannot get the member list + (in the case that they use an at token). + """ + room_id = self.helper.create_room_as("@someone.else:red") + + # first sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEquals(200, channel.code) + sync_token = channel.json_body["next_batch"] + + # check that permission is denied for @sid1:red to get the + # memberships of @someone.else:red's room. + channel = self.make_request( + "GET", + f"/rooms/{room_id}/members?at={sync_token}", + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_permission_former_member(self): + """ + Tests that a former member of the room can not get the member list. + """ + # create a room, invite the user and the user joins + room_id = self.helper.create_room_as("@alice:red") + self.helper.invite(room_id, "@alice:red", self.user_id) + self.helper.join(room_id, self.user_id) + + # check that the user can see the member list to start with + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # ban the user + self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") + + # check the user can no longer see the member list + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_permission_former_member_with_at_token(self): + """ + Tests that a former member of the room can not get the member list + (in the case that they use an at token). + """ + # create a room, invite the user and the user joins + room_id = self.helper.create_room_as("@alice:red") + self.helper.invite(room_id, "@alice:red", self.user_id) + self.helper.join(room_id, self.user_id) + + # sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEquals(200, channel.code) + sync_token = channel.json_body["next_batch"] + + # check that the user can see the member list to start with + channel = self.make_request( + "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) + ) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # ban the user (Note: the user is actually allowed to see this event and + # state so that they know they're banned!) + self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") + + # invite a third user and let them join + self.helper.invite(room_id, "@alice:red", "@bob:red") + self.helper.join(room_id, "@bob:red") + + # now, with the original user, sync again to get a new at token + channel = self.make_request("GET", "/sync") + self.assertEquals(200, channel.code) + sync_token = channel.json_body["next_batch"] + + # check the user can no longer see the updated member list + channel = self.make_request( + "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + def test_get_member_list_mixed_memberships(self): room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) @@ -1124,6 +1208,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fallback to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. + self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index c9c99cc5d7..6db7062a8e 100644 --- a/tests/rest/client/v2_alpha/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import sendtodevice, sync +from synapse.rest.client import login, sendtodevice, sync from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 288ee12888..6a0d9a82be 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -16,8 +16,13 @@ from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import ( + directory, + login, + profile, + room, + room_upgrade_rest_servlet, +) from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py index cedb9614a8..283eccd53f 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/test_shared_rooms.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import shared_rooms +from synapse.rest.client import login, room, shared_rooms from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/test_sync.py index f6ae9ae181..95be369d4b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -21,8 +21,7 @@ from synapse.api.constants import ( ReadReceiptEventFields, RelationTypes, ) -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import knock, read_marker, receipts, sync +from synapse.rest.client import knock, login, read_marker, receipts, room, sync from tests import unittest from tests.federation.transport.test_knocking import ( @@ -418,6 +417,18 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's hidden read receipt self.assertEqual(self._get_read_receipt(), None) + def test_read_receipt_with_empty_body(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + def _get_read_receipt(self): """Syncs and returns the read receipt.""" diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 28dd47a28b..0ae4029640 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -19,7 +19,7 @@ from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import Requester, StateMap from synapse.util.frozenutils import unfreeze diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/test_typing.py index 44e22ca999..b54b004733 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -17,7 +17,7 @@ from unittest.mock import Mock -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 5f3f15fc57..72f976d8e2 100644 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -15,8 +15,7 @@ from typing import Optional from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import login, room, room_upgrade_rest_servlet from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/utils.py index fc2d35596e..954ad1a1fd 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/utils.py @@ -47,10 +47,10 @@ class RestHelper: def create_room_as( self, - room_creator: str = None, + room_creator: Optional[str] = None, is_public: bool = True, - room_version: str = None, - tok: str = None, + room_version: Optional[str] = None, + tok: Optional[str] = None, expect_code: int = 200, extra_content: Optional[Dict] = None, custom_headers: Optional[ diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py deleted file mode 100644 index 5e83dba2ed..0000000000 --- a/tests/rest/client/v1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 --- a/tests/rest/client/v2_alpha/__init__.py +++ /dev/null diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2d6b49692e..2f7eebfe69 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -21,7 +21,7 @@ from unittest.mock import Mock from urllib import parse import attr -from parameterized import parameterized_class +from parameterized import parameterized, parameterized_class from PIL import Image as Image from twisted.internet import defer @@ -30,7 +30,7 @@ from twisted.internet.defer import Deferred from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage @@ -473,6 +473,43 @@ class MediaRepoTests(unittest.HomeserverTestCase): }, ) + @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) + def test_same_quality(self, method, desired_size): + """Test that choosing between thumbnails with the same quality rating succeeds. + + We are not particular about which thumbnail is chosen.""" + self.assertIsNotNone( + self.thumbnail_resource._select_thumbnail( + desired_width=desired_size, + desired_height=desired_size, + desired_method=method, + desired_type=self.test_image.content_type, + # Provide two identical thumbnails which are guaranteed to have the same + # quality rating. + thumbnail_infos=[ + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail1{self.test_image.extension}", + }, + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail2{self.test_image.extension}", + }, + ], + file_id=f"image{self.test_image.extension}", + url_cache=None, + server_name=None, + ) + ) + def test_x_robots_tag_header(self): """ Tests that the `X-Robots-Tag` header is present, which informs web crawlers diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index d3ef7bb4c6..7fa9027227 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -14,13 +14,14 @@ import json import os import re -from unittest.mock import patch from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError from twisted.test.proto_helpers import AccumulatingProtocol +from synapse.config.oembed import OEmbedEndpointConfig + from tests import unittest from tests.server import FakeTransport @@ -81,6 +82,19 @@ class URLPreviewTests(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) + # After the hs is created, modify the parsed oEmbed config (to avoid + # messing with files). + # + # Note that HTTP URLs are used to avoid having to deal with TLS in tests. + hs.config.oembed.oembed_patterns = [ + OEmbedEndpointConfig( + api_endpoint="http://publish.twitter.com/oembed", + url_patterns=[ + re.compile(r"http://twitter\.com/.+/status/.+"), + ], + ) + ] + return hs def prepare(self, reactor, clock, hs): @@ -544,123 +558,101 @@ class URLPreviewTests(unittest.HomeserverTestCase): def test_oembed_photo(self): """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" - # Route the HTTP version to an HTTP endpoint so that the tests work. - with patch.dict( - "synapse.rest.media.v1.preview_url_resource._oembed_patterns", - { - re.compile( - r"http://twitter\.com/.+/status/.+" - ): "http://publish.twitter.com/oembed", - }, - clear=True, - ): - - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] - self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - result = { - "version": "1.0", - "type": "photo", - "url": "http://cdn.twitter.com/matrixdotorg", - } - oembed_content = json.dumps(result).encode("utf-8") - - end_content = ( - b"<html><head>" - b"<title>Some Title</title>" - b'<meta property="og:description" content="hi" />' - b"</head></html>" - ) + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] - channel = self.make_request( - "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(oembed_content),) - + oembed_content - ) + result = { + "version": "1.0", + "type": "photo", + "url": "http://cdn.twitter.com/matrixdotorg", + } + oembed_content = json.dumps(result).encode("utf-8") - self.pump() - - client = self.reactor.tcpClients[1][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content + end_content = ( + b"<html><head>" + b"<title>Some Title</title>" + b'<meta property="og:description" content="hi" />' + b"</head></html>" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' ) + % (len(oembed_content),) + + oembed_content + ) - self.pump() + self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) + % (len(end_content),) + + end_content + ) + + self.pump() + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + ) def test_oembed_rich(self): """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" - # Route the HTTP version to an HTTP endpoint so that the tests work. - with patch.dict( - "synapse.rest.media.v1.preview_url_resource._oembed_patterns", - { - re.compile( - r"http://twitter\.com/.+/status/.+" - ): "http://publish.twitter.com/oembed", - }, - clear=True, - ): - - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - result = { - "version": "1.0", - "type": "rich", - "html": "<div>Content Preview</div>", - } - end_content = json.dumps(result).encode("utf-8") - - channel = self.make_request( - "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "rich", + "html": "<div>Content Preview</div>", + } + end_content = json.dumps(result).encode("utf-8") + + channel = self.make_request( + "GET", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, - {"og:title": None, "og:description": "Content Preview"}, + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + {"og:title": None, "og:description": "Content Preview"}, + ) diff --git a/tests/server.py b/tests/server.py index 6fddd3b305..b861c7b866 100644 --- a/tests/server.py +++ b/tests/server.py @@ -10,9 +10,10 @@ from zope.interface import implementer from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier -from twisted.internet.defer import Deferred, fail, succeed +from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( + IAddress, IHostnameResolver, IProtocol, IPullProducer, @@ -511,6 +512,9 @@ class FakeTransport: will get called back for connectionLost() notifications etc. """ + _peer_address: Optional[IAddress] = attr.ib(default=None) + """The value to be returend by getPeer""" + disconnecting = False disconnected = False connected = True @@ -519,7 +523,7 @@ class FakeTransport: autoflush = attr.ib(default=True) def getPeer(self): - return None + return self._peer_address def getHost(self): return None @@ -572,7 +576,12 @@ class FakeTransport: self.producerStreaming = streaming def _produce(): - d = self.producer.resumeProducing() + if not self.producer: + # we've been unregistered + return + # some implementations of IProducer (for example, FileSender) + # don't return a deferred. + d = maybeDeferred(self.producer.resumeProducing) d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) if not streaming: diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index ac98259b7e..58b399a043 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -15,8 +15,7 @@ import os import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests import unittest diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3245aa91ca..8701b5f7e3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -19,8 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 932970fd9a..a649e8c618 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -14,7 +14,10 @@ import json from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success(self.store.have_seen_events("room1", ["event10"])) self.assertEquals(res, {"event10"}) self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + +class EventCacheTestCase(unittest.HomeserverTestCase): + """Test that the various layers of event cache works.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + def test_simple(self): + """Test that we cache events that we pull from the DB.""" + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + def test_dedupe(self): + """Test that if we request the same event multiple times we only pull it + out once. + """ + + with LoggingContext("test") as ctx: + d = yieldable_gather_results( + self.store.get_event, [self.event_id, self.event_id] + ) + self.get_success(d) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py new file mode 100644 index 0000000000..ffee707153 --- /dev/null +++ b/tests/storage/databases/main/test_room.py @@ -0,0 +1,98 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.storage.databases.main.room import _BackgroundUpdates + +from tests.unittest import HomeserverTestCase + + +class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + + def _generate_room(self) -> str: + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + return room_id + + def test_background_populate_rooms_creator_column(self): + """Test that the background update to populate the rooms creator column + works properly. + """ + + # Insert a room without the creator + room_id = self._generate_room() + self.get_success( + self.store.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"creator": None}, + desc="test", + ) + ) + + # Make sure the test is starting out with a room without a creator + room_creator_before = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="creator", + allow_none=True, + ) + ) + self.assertEqual(room_creator_before, None) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + "progress_json": "{}", + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Make sure the background update filled in the room creator + room_creator_after = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="creator", + allow_none=True, + ) + ) + self.assertEqual(room_creator_after, self.user_id) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 77c4fe721c..da98733ce8 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -17,7 +17,7 @@ from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage import prepare_database from synapse.types import UserID, create_requester diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index e57fce9694..1c2df54ecc 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import synapse.rest.admin from synapse.http.site import XForwardedForRequest -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest from tests.server import make_request diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index d87f124c26..93136f0717 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main.events import _LinkMap from synapse.types import create_requester diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a0e2259478..c3fcf7e7b4 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -15,7 +15,9 @@ import attr from parameterized import parameterized +from synapse.api.room_versions import RoomVersions from synapse.events import _EventInternalMetadata +from synapse.util import json_encoder import tests.unittest import tests.utils @@ -504,6 +506,61 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertSetEqual(difference, set()) + def test_prune_inbound_federation_queue(self): + "Test that pruning of inbound federation queues work" + + room_id = "some_room_id" + + # Insert a bunch of events that all reference the previous one. + self.get_success( + self.store.db_pool.simple_insert_many( + table="federation_inbound_events_staging", + values=[ + { + "origin": "some_origin", + "room_id": room_id, + "received_ts": 0, + "event_id": f"$fake_event_id_{i + 1}", + "event_json": json_encoder.encode( + {"prev_events": [f"$fake_event_id_{i}"]} + ), + "internal_metadata": "{}", + } + for i in range(500) + ], + desc="test_prune_inbound_federation_queue", + ) + ) + + # Calling prune once should return True, i.e. a prune happen. The second + # time it shouldn't. + pruned = self.get_success( + self.store.prune_staged_events_in_room(room_id, RoomVersions.V6) + ) + self.assertTrue(pruned) + + pruned = self.get_success( + self.store.prune_staged_events_in_room(room_id, RoomVersions.V6) + ) + self.assertFalse(pruned) + + # Assert that we only have a single event left in the queue, and that it + # is the last one. + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcol="COALESCE(COUNT(*), 0)", + desc="test_prune_inbound_federation_queue", + ) + ) + self.assertEqual(count, 1) + + _, event_id = self.get_success( + self.store.get_next_staged_event_id_for_room(room_id) + ) + self.assertEqual(event_id, "$fake_event_id_500") + @attr.s class FakeEvent: diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 617bc8091f..f462a8b1c7 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -17,7 +17,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.federation.federation_base import event_from_pdu_json from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index e5574063f1..22a77c3ccc 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.errors import NotFoundError, SynapseError -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9fa968f6bb..c72dc40510 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 9be51f9ebd..6ff3ebb137 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -32,5 +32,5 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): db_pool = self.hs.get_datastores().databases[0] # force txn limit to roll over at least once - for i in range(0, 1001): + for _ in range(0, 1001): self.get_success_or_raise(db_pool.runInteraction("test_select", do_select)) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index e5550aec4d..6ebd01bcbe 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -384,7 +384,7 @@ class EventAuthTestCase(unittest.TestCase): }, ) event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, authorised_join_event, auth_events, do_sig_check=False, @@ -400,7 +400,7 @@ class EventAuthTestCase(unittest.TestCase): "@inviter:foo.test" ) event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event( pleb, additional_content={ @@ -414,7 +414,7 @@ class EventAuthTestCase(unittest.TestCase): # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event(pleb), auth_events, do_sig_check=False, @@ -427,7 +427,7 @@ class EventAuthTestCase(unittest.TestCase): ) with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event( pleb, additional_content={ @@ -442,7 +442,7 @@ class EventAuthTestCase(unittest.TestCase): # *would* be valid, but is sent be a different user.) with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _member_event( pleb, "join", @@ -459,7 +459,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, authorised_join_event, auth_events, do_sig_check=False, @@ -468,7 +468,7 @@ class EventAuthTestCase(unittest.TestCase): # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, authorised_join_event, auth_events, do_sig_check=False, @@ -478,7 +478,7 @@ class EventAuthTestCase(unittest.TestCase): # be authorised since the user is already joined.) auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event(pleb), auth_events, do_sig_check=False, @@ -490,7 +490,7 @@ class EventAuthTestCase(unittest.TestCase): pleb, "invite", sender=creator ) event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event(pleb), auth_events, do_sig_check=False, diff --git a/tests/test_federation.py b/tests/test_federation.py index 0ed8326f55..61c9d7c2ef 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,10 +75,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = ( - lambda origin, event, context, state, auth_events, backfilled: succeed( - context - ) + federation_event_handler = self.homeserver.get_federation_event_handler() + federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + context ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( @@ -88,9 +87,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Send the join, it should return None (which is not an error) self.assertEqual( self.get_success( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) + federation_event_handler.on_receive_pdu("test.serv", join_event) ), None, ) @@ -135,11 +132,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) + federation_event_handler = self.homeserver.get_federation_event_handler() with LoggingContext("test-context"): failure = self.get_failure( - self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True - ), + federation_event_handler.on_receive_pdu("test.serv", lying_event), FederationError, ) diff --git a/tests/test_mau.py b/tests/test_mau.py index fa6ef92b3b..66111eb367 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,7 +17,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService -from synapse.rest.client.v2_alpha import register, sync +from synapse.rest.client import register, sync from tests import unittest from tests.unittest import override_config diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0df480db9f..67dcf567cd 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactorClock -from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.rest.client.register import register_servlets from synapse.util import Clock from tests import unittest diff --git a/tests/unittest.py b/tests/unittest.py index 3eec9c4d5b..f2c90cc47b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase): reactor=self.reactor, ) - from tests.rest.client.v1.utils import RestHelper + from tests.rest.client.utils import RestHelper self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) |