From ecd823d766fecdd7e1c7163073c097a0084122e2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 20 Aug 2021 17:50:44 +0100 Subject: Flatten tests/rest/client/{v1,v2_alpha} too (#10667) --- changelog.d/10667.misc | 1 + mypy.ini | 4 +- tests/rest/client/test_account.py | 1000 +++++++++ tests/rest/client/test_auth.py | 717 +++++++ tests/rest/client/test_capabilities.py | 212 ++ tests/rest/client/test_directory.py | 154 ++ tests/rest/client/test_events.py | 156 ++ tests/rest/client/test_filter.py | 111 + tests/rest/client/test_keys.py | 91 + tests/rest/client/test_login.py | 1345 ++++++++++++ tests/rest/client/test_password_policy.py | 177 ++ tests/rest/client/test_presence.py | 76 + tests/rest/client/test_profile.py | 270 +++ tests/rest/client/test_push_rule_attrs.py | 414 ++++ tests/rest/client/test_register.py | 746 +++++++ tests/rest/client/test_relations.py | 724 +++++++ tests/rest/client/test_report_event.py | 82 + tests/rest/client/test_rooms.py | 2150 ++++++++++++++++++++ tests/rest/client/test_sendtodevice.py | 200 ++ tests/rest/client/test_shared_rooms.py | 142 ++ tests/rest/client/test_sync.py | 686 +++++++ tests/rest/client/test_typing.py | 121 ++ tests/rest/client/test_upgrade_room.py | 159 ++ tests/rest/client/utils.py | 654 ++++++ tests/rest/client/v1/__init__.py | 13 - tests/rest/client/v1/test_directory.py | 154 -- tests/rest/client/v1/test_events.py | 156 -- tests/rest/client/v1/test_login.py | 1345 ------------ tests/rest/client/v1/test_presence.py | 76 - tests/rest/client/v1/test_profile.py | 270 --- tests/rest/client/v1/test_push_rule_attrs.py | 414 ---- tests/rest/client/v1/test_rooms.py | 2150 -------------------- tests/rest/client/v1/test_typing.py | 121 -- tests/rest/client/v1/utils.py | 654 ------ tests/rest/client/v2_alpha/__init__.py | 0 tests/rest/client/v2_alpha/test_account.py | 1000 --------- tests/rest/client/v2_alpha/test_auth.py | 717 ------- tests/rest/client/v2_alpha/test_capabilities.py | 212 -- tests/rest/client/v2_alpha/test_filter.py | 111 - tests/rest/client/v2_alpha/test_keys.py | 91 - tests/rest/client/v2_alpha/test_password_policy.py | 177 -- tests/rest/client/v2_alpha/test_register.py | 746 ------- tests/rest/client/v2_alpha/test_relations.py | 724 ------- tests/rest/client/v2_alpha/test_report_event.py | 82 - tests/rest/client/v2_alpha/test_sendtodevice.py | 200 -- tests/rest/client/v2_alpha/test_shared_rooms.py | 142 -- tests/rest/client/v2_alpha/test_sync.py | 686 ------- tests/rest/client/v2_alpha/test_upgrade_room.py | 159 -- tests/unittest.py | 2 +- 49 files changed, 10391 insertions(+), 10403 deletions(-) create mode 100644 changelog.d/10667.misc create mode 100644 tests/rest/client/test_account.py create mode 100644 tests/rest/client/test_auth.py create mode 100644 tests/rest/client/test_capabilities.py create mode 100644 tests/rest/client/test_directory.py create mode 100644 tests/rest/client/test_events.py create mode 100644 tests/rest/client/test_filter.py create mode 100644 tests/rest/client/test_keys.py create mode 100644 tests/rest/client/test_login.py create mode 100644 tests/rest/client/test_password_policy.py create mode 100644 tests/rest/client/test_presence.py create mode 100644 tests/rest/client/test_profile.py create mode 100644 tests/rest/client/test_push_rule_attrs.py create mode 100644 tests/rest/client/test_register.py create mode 100644 tests/rest/client/test_relations.py create mode 100644 tests/rest/client/test_report_event.py create mode 100644 tests/rest/client/test_rooms.py create mode 100644 tests/rest/client/test_sendtodevice.py create mode 100644 tests/rest/client/test_shared_rooms.py create mode 100644 tests/rest/client/test_sync.py create mode 100644 tests/rest/client/test_typing.py create mode 100644 tests/rest/client/test_upgrade_room.py create mode 100644 tests/rest/client/utils.py delete mode 100644 tests/rest/client/v1/__init__.py delete mode 100644 tests/rest/client/v1/test_directory.py delete mode 100644 tests/rest/client/v1/test_events.py delete mode 100644 tests/rest/client/v1/test_login.py delete mode 100644 tests/rest/client/v1/test_presence.py delete mode 100644 tests/rest/client/v1/test_profile.py delete mode 100644 tests/rest/client/v1/test_push_rule_attrs.py delete mode 100644 tests/rest/client/v1/test_rooms.py delete mode 100644 tests/rest/client/v1/test_typing.py delete mode 100644 tests/rest/client/v1/utils.py delete mode 100644 tests/rest/client/v2_alpha/__init__.py delete mode 100644 tests/rest/client/v2_alpha/test_account.py delete mode 100644 tests/rest/client/v2_alpha/test_auth.py delete mode 100644 tests/rest/client/v2_alpha/test_capabilities.py delete mode 100644 tests/rest/client/v2_alpha/test_filter.py delete mode 100644 tests/rest/client/v2_alpha/test_keys.py delete mode 100644 tests/rest/client/v2_alpha/test_password_policy.py delete mode 100644 tests/rest/client/v2_alpha/test_register.py delete mode 100644 tests/rest/client/v2_alpha/test_relations.py delete mode 100644 tests/rest/client/v2_alpha/test_report_event.py delete mode 100644 tests/rest/client/v2_alpha/test_sendtodevice.py delete mode 100644 tests/rest/client/v2_alpha/test_shared_rooms.py delete mode 100644 tests/rest/client/v2_alpha/test_sync.py delete mode 100644 tests/rest/client/v2_alpha/test_upgrade_room.py diff --git a/changelog.d/10667.misc b/changelog.d/10667.misc new file mode 100644 index 0000000000..c92846ae26 --- /dev/null +++ b/changelog.d/10667.misc @@ -0,0 +1 @@ +Flatten the `tests.synapse.rests` package by moving the contents of `v1` and `v2_alpha` into the parent. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 90ade37b3f..b17872211e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,8 +91,8 @@ files = tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, tests/handlers/test_sync.py, - tests/rest/client/v1/test_login.py, - tests/rest/client/v2_alpha/test_auth.py, + tests/rest/client/test_login.py, + tests/rest/client/test_auth.py, tests/util/test_itertools.py, tests/util/test_stream_change_cache.py diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py new file mode 100644 index 0000000000..b946fca8b3 --- /dev/null +++ b/tests/rest/client/test_account.py @@ -0,0 +1,1000 @@ +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import re +from email.parser import Parser +from typing import Optional + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes, HttpResponseException +from synapse.appservice import ApplicationService +from synapse.rest.client import account, login, register, room +from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource + +from tests import unittest +from tests.server import FakeSite, make_request +from tests.unittest import override_config + + +class PasswordResetTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + + return hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.submit_token_resource = PasswordResetSubmitTokenResource(hs) + + def test_basic_password_reset(self): + """Test basic password reset flow""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email.""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. + # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + + def test_basic_password_reset_canonicalise_email(self): + """Test basic password reset flow + Request password reset with different spelling + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email_profile = "test@example.com" + email_passwort_reset = "TEST@EXAMPLE.COM" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email_profile, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email_passwort_reset, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + + def test_cant_reset_password_without_clicking_link(self): + """Test that we do actually need to click the link in the email""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to reset password without clicking the link + self._reset_password(new_password, session_id, client_secret, expected_code=401) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = "weasle" + + # Attempt to reset password without even requesting an email + self._reset_password(new_password, session_id, client_secret, expected_code=401) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + @unittest.override_config({"request_token_inhibit_3pid_errors": True}) + def test_password_reset_bad_email_inhibit_error(self): + """Test that triggering a password reset with an email address that isn't bound + to an account doesn't leak the lack of binding for that address if configured + that way. + """ + self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertIsNotNone(session_id) + + def _request_token(self, email, client_secret, ip="127.0.0.1"): + channel = self.make_request( + "POST", + b"account/password/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, + ) + + if channel.code != 200: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + # Load the password reset confirmation page + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "GET", + path, + shorthand=False, + ) + + self.assertEquals(200, channel.code, channel.result) + + # Now POST to the same endpoint, mimicking the same behaviour as clicking the + # password reset confirm button + + # Confirm the password reset + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "POST", + path, + content=b"", + shorthand=False, + content_is_form=True, + ) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) + + def _reset_password( + self, new_password, session_id, client_secret, expected_code=200 + ): + channel = self.make_request( + "POST", + b"account/password", + { + "new_password": new_password, + "auth": { + "type": LoginType.EMAIL_IDENTITY, + "threepid_creds": { + "client_secret": client_secret, + "sid": session_id, + }, + }, + }, + ) + self.assertEquals(expected_code, channel.code, channel.result) + + +class DeactivateTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def test_deactivate_account(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + self.deactivate(user_id, tok) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + channel = self.make_request("GET", "account/whoami", access_token=tok) + self.assertEqual(channel.code, 401) + + def test_pending_invites(self): + """Tests that deactivating a user rejects every pending invite for them.""" + store = self.hs.get_datastore() + + inviter_id = self.register_user("inviter", "test") + inviter_tok = self.login("inviter", "test") + + invitee_id = self.register_user("invitee", "test") + invitee_tok = self.login("invitee", "test") + + # Make @inviter:test invite @invitee:test in a new room. + room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) + self.helper.invite( + room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok + ) + + # Make sure the invite is here. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 1, pending_invites) + self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) + + # Deactivate @invitee:test. + self.deactivate(invitee_id, invitee_tok) + + # Check that the invite isn't there anymore. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 0, pending_invites) + + # Check that the membership of @invitee:test in the room is now "leave". + memberships = self.get_success( + store.get_rooms_for_local_user_where_membership_is( + invitee_id, [Membership.LEAVE] + ) + ) + self.assertEqual(len(memberships), 1, memberships) + self.assertEqual(memberships[0].room_id, room_id, memberships) + + def deactivate(self, user_id, tok): + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } + ) + channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(channel.code, 200) + + +class WhoamiTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + register.register_servlets, + ] + + def test_GET_whoami(self): + device_id = "wouldgohere" + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test", device_id=device_id) + + whoami = self.whoami(tok) + self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) + + def test_GET_whoami_appservices(self): + user_id = "@as:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": user_id, "exclusive": True}]}, + sender=user_id, + ) + self.hs.get_datastore().services_cache.append(appservice) + + whoami = self.whoami(as_token) + self.assertEqual(whoami, {"user_id": user_id}) + self.assertFalse(hasattr(whoami, "device_id")) + + def whoami(self, tok): + channel = self.make_request("GET", "account/whoami", {}, access_token=tok) + self.assertEqual(channel.code, 200) + return channel.json_body + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_valid_email(self): + self.get_success(self._add_email(self.email, self.email)) + + def test_add_valid_email_second_time(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) + + def test_add_valid_email_second_time_canonicalise(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) + + def test_add_email_no_at(self): + self.get_success( + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_two_at(self): + self.get_success( + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_bad_format(self): + self.get_success( + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_domain_to_lower(self): + self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + + def test_add_email_domain_with_umlaut(self): + self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + + def test_add_email_address_casefold(self): + self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + + def test_address_trim(self): + self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP""" + + # We expect to be able to set three emails before getting ratelimited. + self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed""" + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile""" + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed""" + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email""" + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + + # Ensure not providing a next_link parameter still works + self._request_token( + "something@example.com", + "some_secret", + next_link=None, + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + body, + ) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) + + return channel.json_body.get("sid") + + def _request_token_invalid_email( + self, + email, + expected_errcode, + expected_error, + client_secret="foobar", + ): + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(expected_errcode, channel.json_body["errcode"]) + self.assertEqual(expected_error, channel.json_body["error"]) + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + channel = self.make_request("GET", path, shorthand=False) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) + + def _add_email(self, request_email, expected_email): + """Test adding an email to profile""" + previous_email_attempts = len(self.email_attempts) + + client_secret = "foobar" + session_id = self._request_token(request_email, client_secret) + + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) + link = self._get_link_from_email() + + self._validate_token(link) + + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py new file mode 100644 index 0000000000..e2fcbdc63a --- /dev/null +++ b/tests/rest/client/test_auth.py @@ -0,0 +1,717 @@ +# Copyright 2018 New Vector +# Copyright 2020-2021 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +from twisted.internet.defer import succeed + +import synapse.rest.admin +from synapse.api.constants import LoginType +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.rest.client import account, auth, devices, login, register +from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.types import JsonDict, UserID + +from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC +from tests.rest.client.utils import TEST_OIDC_CONFIG +from tests.server import FakeChannel +from tests.unittest import override_config, skip_unless + + +class DummyRecaptchaChecker(UserInteractiveAuthChecker): + def __init__(self, hs): + super().__init__(hs) + self.recaptcha_attempts = [] + + def check_auth(self, authdict, clientip): + self.recaptcha_attempts.append((authdict, clientip)) + return succeed(True) + + +class FallbackAuthTests(unittest.HomeserverTestCase): + + servlets = [ + auth.register_servlets, + register.register_servlets, + ] + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + + config["enable_registration_captcha"] = True + config["recaptcha_public_key"] = "brokencake" + config["registrations_require_3pid"] = [] + + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker + + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" + channel = self.make_request("POST", "register", body) + + self.assertEqual(channel.code, expected_response) + return channel + + def recaptcha( + self, + session: str, + expected_post_response: int, + post_session: Optional[str] = None, + ) -> None: + """Get and respond to a fallback recaptcha. Returns the second request.""" + if post_session is None: + post_session = session + + channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.assertEqual(channel.code, 200) + + channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + post_session + + "&g-recaptcha-response=a", + ) + self.assertEqual(channel.code, expected_post_response) + + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" + # Returns a 401 as per the spec + channel = self.register( + 401, + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step, we can then send a + # request to the register API with the session in the authdict. + channel = self.register(200, {"auth": {"session": session}}) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + # Make the initial request to register. (Later on a different password + # will be used.) + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} + ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" + + if HAS_OIDC: + # we enable OIDC as a way of testing SSO flows + oidc_config = {} + oidc_config.update(TEST_OIDC_CONFIG) + oidc_config["allow_existing_users"] = True + config["oidc_config"] = oidc_config + + return config + + def create_resource_dict(self): + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) + + def delete_device( + self, + access_token: str, + device: str, + expected_response: int, + body: Union[bytes, JsonDict] = b"", + ) -> FakeChannel: + """Delete an individual device.""" + channel = self.make_request( + "DELETE", + "devices/" + device, + body, + access_token=access_token, + ) + + # Ensure the response is sane. + self.assertEqual(channel.code, expected_response) + + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. + channel = self.make_request( + "POST", + "delete_devices", + body, + access_token=self.user_tok, + ) + + # Ensure the response is sane. + self.assertEqual(channel.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + self.device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_grandfathered_identifier(self): + """Check behaviour without "identifier" dict + + Synapse used to require clients to submit a "user" field for m.login.password + UIA - check that still works. + """ + + channel = self.delete_device(self.user_tok, self.device_id, 401) + session = channel.json_body["session"] + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + self.device_id, + 200, + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_can_change_body(self): + """ + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. + """ + # Create a second login. + self.login("test", self.user_pass, "dev2") + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [self.device_id]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. + self.delete_devices( + 200, + { + "devices": ["dev2"], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass, "dev2") + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + # + # This makes use of the fact that the device ID is embedded into the URL. + self.delete_device( + self.user_tok, + "dev2", + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) + def test_can_reuse_session(self): + """ + The session can be reused if configured. + + Compare to test_cannot_change_uri. + """ + # Create a second and third login. + self.login("test", self.user_pass, "dev2") + self.login("test", self.user_pass, "dev3") + + # Attempt to delete a device. This works since the user just logged in. + self.delete_device(self.user_tok, "dev2", 200) + + # Move the clock forward past the validation timeout. + self.reactor.advance(6) + + # Deleting another devices throws the user into UI auth. + channel = self.delete_device(self.user_tok, "dev3", 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + "dev3", + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + # Make another request, but try to delete the first device. This works + # due to re-using the previous session. + # + # Note that *no auth* information is provided, not even a session iD! + self.delete_device(self.user_tok, self.device_id, 200) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_via_sso(self): + """Test a successful UI Auth flow via SSO + + This includes: + * hitting the UIA SSO redirect endpoint + * checking it serves a confirmation page which links to the OIDC provider + * calling back to the synapse oidc callback + * checking that the original operation succeeds + """ + + # log the user in + remote_user_id = UserID.from_string(self.user).localpart + login_resp = self.helper.login_via_oidc(remote_user_id) + self.assertEqual(login_resp["user_id"], self.user) + + # initiate a UI Auth process by attempting to delete the device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # check that SSO is offered + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + + # run the UIA-via-SSO flow + session_id = channel.json_body["session"] + channel = self.helper.auth_via_oidc( + {"sub": remote_user_id}, ui_auth_session_id=session_id + ) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + + # and now the delete request should succeed. + self.delete_device( + self.user_tok, + self.device_id, + 200, + body={"auth": {"session": session_id}}, + ) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_does_not_offer_password_for_sso_user(self): + login_resp = self.helper.login_via_oidc("username") + user_tok = login_resp["access_token"] + device_id = login_resp["device_id"] + + # now call the device deletion API: we should get the option to auth with SSO + # and not password. + channel = self.delete_device(user_tok, device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) + + def test_does_not_offer_sso_for_password_user(self): + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.password"]}]) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_offers_both_flows_for_upgraded_user(self): + """A user that had a password and then logged in with SSO should get both flows""" + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + # we have no particular expectations of ordering here + self.assertIn({"stages": ["m.login.password"]}, flows) + self.assertIn({"stages": ["m.login.sso"]}, flows) + self.assertEqual(len(flows), 2) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_fails_for_incorrect_sso_user(self): + """If the user tries to authenticate with the wrong SSO user, they get an error""" + # log the user in + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + # start a UI Auth flow by attempting to delete a device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + session_id = channel.json_body["session"] + + # do the OIDC auth, but auth as the wrong user + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) + + # that should return a failure message + self.assertSubstring("We were unable to validate", channel.text_body) + + # ... and the delete op should now fail with a 403 + self.delete_device( + self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + ) + + +class RefreshAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + hijack_auth = False + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + + def test_login_issue_refresh_token(self): + """ + A login response should include a refresh_token only if asked. + """ + # Test login + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + + login_without_refresh = self.make_request( + "POST", "/_matrix/client/r0/login", body + ) + self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertNotIn("refresh_token", login_without_refresh.json_body) + + login_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertIn("refresh_token", login_with_refresh.json_body) + self.assertIn("expires_in_ms", login_with_refresh.json_body) + + def test_register_issue_refresh_token(self): + """ + A register response should include a refresh_token only if asked. + """ + register_without_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": "test2", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual( + register_without_refresh.code, 200, register_without_refresh.result + ) + self.assertNotIn("refresh_token", register_without_refresh.json_body) + + register_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", + { + "username": "test3", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertIn("refresh_token", register_with_refresh.json_body) + self.assertIn("expires_in_ms", register_with_refresh.json_body) + + def test_token_refresh(self): + """ + A refresh token can be used to issue a new access token. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertIn("access_token", refresh_response.json_body) + self.assertIn("refresh_token", refresh_response.json_body) + self.assertIn("expires_in_ms", refresh_response.json_body) + + # The access and refresh tokens should be different from the original ones after refresh + self.assertNotEqual( + login_response.json_body["access_token"], + refresh_response.json_body["access_token"], + ) + self.assertNotEqual( + login_response.json_body["refresh_token"], + refresh_response.json_body["refresh_token"], + ) + + @override_config({"access_token_lifetime": "1m"}) + def test_refresh_token_expiration(self): + """ + The access token should have some time as specified in the config. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + self.assertApproximates( + login_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertApproximates( + refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + def test_refresh_token_invalidation(self): + """Refresh tokens are invalidated after first use of the next token. + + A refresh token is considered invalid if: + - it was already used at least once + - and either + - the next access token was used + - the next refresh token was used + + The chain of tokens goes like this: + + login -|-> first_refresh -> third_refresh (fails) + |-> second_refresh -> fifth_refresh + |-> fourth_refresh (fails) + """ + + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + # This first refresh should work properly + first_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + first_refresh_response.code, 200, first_refresh_response.result + ) + + # This one as well, since the token in the first one was never used + second_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + second_refresh_response.code, 200, second_refresh_response.result + ) + + # This one should not, since the token from the first refresh is not valid anymore + third_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": first_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + third_refresh_response.code, 401, third_refresh_response.result + ) + + # The associated access token should also be invalid + whoami_response = self.make_request( + "GET", + "/_matrix/client/r0/account/whoami", + access_token=first_refresh_response.json_body["access_token"], + ) + self.assertEqual(whoami_response.code, 401, whoami_response.result) + + # But all other tokens should work (they will expire after some time) + for access_token in [ + second_refresh_response.json_body["access_token"], + login_response.json_body["access_token"], + ]: + whoami_response = self.make_request( + "GET", "/_matrix/client/r0/account/whoami", access_token=access_token + ) + self.assertEqual(whoami_response.code, 200, whoami_response.result) + + # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail + fourth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fourth_refresh_response.code, 403, fourth_refresh_response.result + ) + + # But refreshing from the last valid refresh token still works + fifth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": second_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fifth_refresh_response.code, 200, fifth_refresh_response.result + ) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py new file mode 100644 index 0000000000..ac31e5ceaf --- /dev/null +++ b/tests/rest/client/test_capabilities.py @@ -0,0 +1,212 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.rest.client import capabilities, login + +from tests import unittest +from tests.unittest import override_config + + +class CapabilitiesTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + capabilities.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.url = b"/_matrix/client/r0/capabilities" + hs = self.setup_test_homeserver() + self.config = hs.config + self.auth_handler = hs.get_auth_handler() + return hs + + def prepare(self, reactor, clock, hs): + self.localpart = "user" + self.password = "pass" + self.user = self.register_user(self.localpart, self.password) + + def test_check_auth_required(self): + channel = self.make_request("GET", self.url) + + self.assertEqual(channel.code, 401) + + def test_get_room_version_capabilities(self): + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + for room_version in capabilities["m.room_versions"]["available"].keys(): + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) + + self.assertEqual( + self.config.default_room_version.identifier, + capabilities["m.room_versions"]["default"], + ) + + def test_get_change_password_capabilities_password_login(self): + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"localdb_enabled": False}}) + def test_get_change_password_capabilities_localdb_disabled(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"enabled": False}}) + def test_get_change_password_capabilities_password_disabled(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): + """Test that per default msc3283 is disabled server returns `m.change_password`.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) + self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) + self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) + + @override_config({"experimental_features": {"msc3283_enabled": True}}) + def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): + """Test if msc3283 is enabled server returns capabilities.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + @override_config( + { + "enable_set_displayname": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_displayname_capabilities_displayname_disabled(self): + """Test if set displayname is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + + @override_config( + { + "enable_set_avatar_url": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + """Test if set avatar_url is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + + @override_config( + { + "enable_3pid_changes": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_change_3pid_capabilities_3pid_disabled(self): + """Test if change 3pid is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + def test_get_does_not_include_msc3244_fields_by_default(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertNotIn( + "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] + ) + + @override_config({"experimental_features": {"msc3244_enabled": True}}) + def test_get_does_include_msc3244_fields_when_enabled(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + for details in capabilities["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ].values(): + if details["preferred"] is not None: + self.assertTrue( + details["preferred"] in KNOWN_ROOM_VERSIONS, + str(details["preferred"]), + ) + + self.assertGreater(len(details["support"]), 0) + for room_version in details["support"]: + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py new file mode 100644 index 0000000000..d2181ea907 --- /dev/null +++ b/tests/rest/client/test_directory.py @@ -0,0 +1,154 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.rest import admin +from synapse.rest.client import directory, login, room +from synapse.types import RoomAlias +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.unittest import override_config + + +class DirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_membership_for_aliases"] = True + + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.user = self.register_user("user", "test") + self.user_tok = self.login("user", "test") + + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) + + def test_directory_endpoint_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_directory(403) + + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + + def test_directory_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(400, alias_length=256) + + @override_config({"default_room_version": 5}) + def test_state_event_user_in_v5_room(self): + """Test that a regular user can add alias events before room v6""" + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + + @override_config({"default_room_version": 6}) + def test_state_event_v6_room(self): + """Test that a regular user can *not* add alias events from room v6""" + self.ensure_user_joined_room() + self.set_alias_via_state_event(403) + + def test_directory_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(200) + + def test_room_creation_too_long(self): + url = "/_matrix/client/r0/createRoom" + + # We use deliberately a localpart under the length threshold so + # that we can make sure that the check is done on the whole alias. + data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} + request_data = json.dumps(data) + channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_room_creation(self): + url = "/_matrix/client/r0/createRoom" + + # Check with an alias of allowed length. There should already be + # a test that ensures it works in test_register.py, but let's be + # as cautious as possible here. + data = {"room_alias_name": random_string(5)} + request_data = json.dumps(data) + channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, 200, channel.result) + + def set_alias_via_state_event(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def set_alias_via_directory(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def random_alias(self, length): + return RoomAlias(random_string(length), self.hs.hostname).to_string() + + def ensure_user_left_room(self): + self.ensure_membership("leave") + + def ensure_user_joined_room(self): + self.ensure_membership("join") + + def ensure_membership(self, membership): + try: + if membership == "leave": + self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) + if membership == "join": + self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py new file mode 100644 index 0000000000..a90294003e --- /dev/null +++ b/tests/rest/client/test_events.py @@ -0,0 +1,156 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Tests REST events for /events paths.""" + +from unittest.mock import Mock + +import synapse.rest.admin +from synapse.rest.client import events, login, room + +from tests import unittest + + +class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): + """Tests event streaming (GET /events).""" + + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["enable_registration_captcha"] = False + config["enable_registration"] = True + config["auto_join_rooms"] = [] + + hs = self.setup_test_homeserver(config=config) + + hs.get_federation_handler = Mock() + + return hs + + def prepare(self, reactor, clock, hs): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + # register a 2nd account + self.other_user = self.register_user("other2", "pass") + self.other_token = self.login(self.other_user, "pass") + + def test_stream_basic_permissions(self): + # invalid token, expect 401 + # note: this is in violation of the original v1 spec, which expected + # 403. However, since the v1 spec no longer exists and the v1 + # implementation is now part of the r0 implementation, the newer + # behaviour is used instead to be consistent with the r0 spec. + # see issue #2602 + channel = self.make_request( + "GET", "/events?access_token=%s" % ("invalid" + self.token,) + ) + self.assertEquals(channel.code, 401, msg=channel.result) + + # valid token, expect content + channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) + ) + self.assertEquals(channel.code, 200, msg=channel.result) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("start" in channel.json_body) + self.assertTrue("end" in channel.json_body) + + def test_stream_room_permissions(self): + room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) + self.helper.send(room_id, tok=self.other_token) + + # invited to room (expect no content for room) + self.helper.invite( + room_id, src=self.other_user, targ=self.user_id, tok=self.other_token + ) + + # valid token, expect content + channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) + ) + self.assertEquals(channel.code, 200, msg=channel.result) + + # We may get a presence event for ourselves down + self.assertEquals( + 0, + len( + [ + c + for c in channel.json_body["chunk"] + if not ( + c.get("type") == "m.presence" + and c["content"].get("user_id") == self.user_id + ) + ] + ), + ) + + # joined room (expect all content for room) + self.helper.join(room=room_id, user=self.user_id, tok=self.token) + + # left to room (expect no content for room) + + def TODO_test_stream_items(self): + # new user, no content + + # join room, expect 1 item (join) + + # send message, expect 2 items (join,send) + + # set topic, expect 3 items (join,send,topic) + + # someone else join room, expect 4 (join,send,topic,join) + + # someone else send message, expect 5 (join,send.topic,join,send) + + # someone else set topic, expect 6 (join,send,topic,join,send,topic) + pass + + +class GetEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, hs, reactor, clock): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def test_get_event_via_events(self): + resp = self.helper.send(self.room_id, tok=self.token) + event_id = resp["event_id"] + + channel = self.make_request( + "GET", + "/events/" + event_id, + access_token=self.token, + ) + self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py new file mode 100644 index 0000000000..475c6bed3d --- /dev/null +++ b/tests/rest/client/test_filter.py @@ -0,0 +1,111 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import Codes +from synapse.rest.client import filter + +from tests import unittest + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class FilterTestCase(unittest.HomeserverTestCase): + + user_id = "@apple:test" + hijack_auth = True + EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} + EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' + servlets = [filter.register_servlets] + + def prepare(self, reactor, clock, hs): + self.filtering = hs.get_filtering() + self.store = hs.get_datastore() + + def test_add_filter(self): + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.user_id), + self.EXAMPLE_FILTER_JSON, + ) + + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.json_body, {"filter_id": "0"}) + filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.pump() + self.assertEquals(filter.result, self.EXAMPLE_FILTER) + + def test_add_filter_for_other_user(self): + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), + self.EXAMPLE_FILTER_JSON, + ) + + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + + def test_add_filter_non_local_user(self): + _is_mine = self.hs.is_mine + self.hs.is_mine = lambda target_user: False + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.user_id), + self.EXAMPLE_FILTER_JSON, + ) + + self.hs.is_mine = _is_mine + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + + def test_get_filter(self): + filter_id = defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER + ) + ) + self.reactor.advance(1) + filter_id = filter_id.result + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) + ) + + self.assertEqual(channel.result["code"], b"200") + self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) + + def test_get_filter_non_existant(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"404") + self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) + + # Currently invalid params do not have an appropriate errcode + # in errors.py + def test_get_filter_invalid_id(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"400") + + # No ID also returns an invalid_id error + def test_get_filter_no_id(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py new file mode 100644 index 0000000000..d7fa635eae --- /dev/null +++ b/tests/rest/client/test_keys.py @@ -0,0 +1,91 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from http import HTTPStatus + +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import keys, login + +from tests import unittest + + +class KeyQueryTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def test_rejects_device_id_ice_key_outside_of_list(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: "device_id1", + }, + }, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_rejects_device_key_given_as_map_to_bool(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: { + "device_id1": True, + }, + }, + }, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_requires_device_key(self): + """`device_keys` is required. We should complain if it's missing.""" + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + {}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py new file mode 100644 index 0000000000..5b2243fe52 --- /dev/null +++ b/tests/rest/client/test_login.py @@ -0,0 +1,1345 @@ +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import urllib.parse +from typing import Any, Dict, List, Optional, Union +from unittest.mock import Mock +from urllib.parse import urlencode + +import pymacaroons + +from twisted.web.resource import Resource + +import synapse.rest.admin +from synapse.appservice import ApplicationService +from synapse.rest.client import devices, login, logout, register +from synapse.rest.client.account import WhoamiRestServlet +from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.types import create_requester + +from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC +from tests.handlers.test_saml import has_saml2 +from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.test_utils.html_parsers import TestHtmlParser +from tests.unittest import HomeserverTestCase, override_config, skip_unless + +try: + import jwt + + HAS_JWT = True +except ImportError: + HAS_JWT = False + + +# synapse server name: used to populate public_baseurl in some tests +SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" + +# public_baseurl for some tests. It uses an http:// scheme because +# FakeChannel.isSecure() returns False, so synapse will see the requested uri as +# http://..., so using http in the public_baseurl stops Synapse trying to redirect to +# https://.... +BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) + +# CAS server used in some tests +CAS_SERVER = "https://fake.test" + +# just enough to tell pysaml2 where to redirect to +SAML_SERVER = "https://test.saml.server/idp/sso" +TEST_SAML_METADATA = """ + + + + + +""" % { + "SAML_SERVER": SAML_SERVER, +} + +LOGIN_URL = b"/_matrix/client/r0/login" +TEST_URL = b"/_matrix/client/r0/account/whoami" + +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' + +# the query params in TEST_CLIENT_REDIRECT_URL +EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] + +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + + +class LoginRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + logout.register_servlets, + devices.register_servlets, + lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] + self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False + + return self.hs + + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000}, + } + } + ) + def test_POST_ratelimiting_per_address(self): + # Create different users so we're sure not to be bothered by the per-user + # ratelimiter. + for i in range(0, 6): + self.register_user("kermit" + str(i), "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + } + } + ) + def test_POST_ratelimiting_per_account(self): + self.register_user("kermit", "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.0) + + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, + } + } + ) + def test_POST_ratelimiting_per_account_failed_attempts(self): + self.register_user("kermit", "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "notamonkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"403", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "notamonkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_soft_logout(self): + self.register_user("kermit", "monkey") + + # we shouldn't be able to make requests without an access token + channel = self.make_request(b"GET", TEST_URL) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + + # log in as normal + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.code, 200, channel.result) + access_token = channel.json_body["access_token"] + device_id = channel.json_body["device_id"] + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # + # test behaviour after deleting the expired device + # + + # we now log in as a different device + access_token_2 = self.login("kermit", "monkey") + + # more requests with the expired token should still return a soft-logout + self.reactor.advance(3600) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # ... but if we delete that device, it will be a proper logout + self._delete_device(access_token_2, "kermit", "monkey", device_id) + + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], False) + + def _delete_device(self, access_token, user_id, password, device_id): + """Perform the UI-Auth to delete a device""" + channel = self.make_request( + b"DELETE", "devices/" + device_id, access_token=access_token + ) + self.assertEquals(channel.code, 401, channel.result) + # check it's a UI-Auth fail + self.assertEqual( + set(channel.json_body.keys()), + {"flows", "params", "session"}, + channel.result, + ) + + auth = { + "type": "m.login.password", + # https://github.com/matrix-org/synapse/issues/5665 + # "identifier": {"type": "m.id.user", "user": user_id}, + "user": user_id, + "password": password, + "session": channel.json_body["session"], + } + + channel = self.make_request( + b"DELETE", + "devices/" + device_id, + access_token=access_token, + content={"auth": auth}, + ) + self.assertEquals(channel.code, 200, channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard logout this session + channel = self.make_request(b"POST", "/logout", access_token=access_token) + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard log out all of the user's sessions + channel = self.make_request(b"POST", "/logout/all", access_token=access_token) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") +class MultiSSOTestCase(unittest.HomeserverTestCase): + """Tests for homeservers with multiple SSO providers enabled""" + + servlets = [ + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + config["public_baseurl"] = BASE_URL + + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + "service_url": "https://matrix.goodserver.com:8448", + } + + config["saml2_config"] = { + "sp_config": { + "metadata": {"inline": [TEST_SAML_METADATA]}, + # use the XMLSecurity backend to avoid relying on xmlsec1 + "crypto_backend": "XMLSecurity", + }, + } + + # default OIDC provider + config["oidc_config"] = TEST_OIDC_CONFIG + + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d.update(build_synapse_client_resource_tree(self.hs)) + return d + + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flow_types = [ + "m.login.cas", + "m.login.sso", + "m.login.token", + "m.login.password", + ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] + + self.assertCountEqual( + [f["type"] for f in channel.json_body["flows"]], expected_flow_types + ) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results: Dict[str, Any] = {} + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + + def test_multi_sso_redirect(self): + """/login/sso/redirect should redirect to an identity picker""" + # first hit the redirect url, which should redirect to our idp picker + channel = self._make_sso_redirect_request(False, None) + self.assertEqual(channel.code, 302, channel.result) + uri = channel.headers.getRawHeaders("Location")[0] + + # hitting that picker should give us some HTML + channel = self.make_request("GET", uri) + self.assertEqual(channel.code, 200, channel.result) + + # parse the form to check it has fields assumed elsewhere in this class + html = channel.result["body"].decode("utf-8") + p = TestHtmlParser() + p.feed(html) + p.close() + + # there should be a link for each href + returned_idps: List[str] = [] + for link in p.links: + path, query = link.split("?", 1) + self.assertEqual(path, "pick_idp") + params = urllib.parse.parse_qs(query) + self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) + returned_idps.append(params["idp"][0]) + + self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) + + def test_multi_sso_redirect_to_cas(self): + """If CAS is chosen, should redirect to the CAS server""" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", + shorthand=False, + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + cas_uri = location_headers[0] + cas_uri_path, cas_uri_query = cas_uri.split("?", 1) + + # it should redirect us to the login page of the cas server + self.assertEqual(cas_uri_path, CAS_SERVER + "/login") + + # check that the redirectUrl is correctly encoded in the service param - ie, the + # place that CAS will redirect to + cas_uri_params = urllib.parse.parse_qs(cas_uri_query) + service_uri = cas_uri_params["service"][0] + _, service_uri_query = service_uri.split("?", 1) + service_uri_params = urllib.parse.parse_qs(service_uri_query) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) + + def test_multi_sso_redirect_to_saml(self): + """If SAML is chosen, should redirect to the SAML server""" + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=saml", + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + saml_uri = location_headers[0] + saml_uri_path, saml_uri_query = saml_uri.split("?", 1) + + # it should redirect us to the login page of the SAML server + self.assertEqual(saml_uri_path, SAML_SERVER) + + # the RelayState is used to carry the client redirect url + saml_uri_params = urllib.parse.parse_qs(saml_uri_query) + relay_state_param = saml_uri_params["RelayState"][0] + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) + + def test_login_via_oidc(self): + """If OIDC is chosen, should redirect to the OIDC auth endpoint""" + + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + # ... and should have set a cookie including the redirect url + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + oidc_session_cookie = cookies["oidc_session"] + macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) + self.assertEqual( + self._get_value_from_macaroon(macaroon, "client_redirect_url"), + TEST_CLIENT_REDIRECT_URL, + ) + + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + content_type_headers = channel.headers.getRawHeaders("Content-Type") + assert content_type_headers + self.assertTrue(content_type_headers[-1].startswith("text/html")) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + + def test_multi_sso_redirect_to_unknown(self): + """An unknown IdP should cause a 400""" + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self._make_sso_redirect_request(False, "xxx") + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self._make_sso_redirect_request(False, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_msc2858_redirect_to_oidc(self): + """Test the unstable API""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + def _make_sso_redirect_request( + self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None + ): + """Send a request to /_matrix/client/r0/login/sso/redirect + + ... or the unstable equivalent + + ... possibly specifying an IDP provider + """ + endpoint = ( + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" + if unstable_endpoint + else "/_matrix/client/r0/login/sso/redirect" + ) + if idp_prov is not None: + endpoint += "/" + idp_prov + endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + return self.make_request( + "GET", + endpoint, + custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], + ) + + @staticmethod + def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + +class CASTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.base_url = "https://matrix.goodserver.com/" + self.redirect_path = "_synapse/client/login/sso/redirect/confirm" + + config = self.default_config() + config["public_baseurl"] = ( + config.get("public_baseurl") or "https://matrix.goodserver.com:8448" + ) + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + } + + cas_user_id = "username" + self.user_id = "@%s:test" % cas_user_id + + async def get_raw(uri, args): + """Return an example response payload from a call to the `/proxyValidate` + endpoint of a CAS server, copied from + https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 + + This needs to be returned by an async function (as opposed to set as the + mock's return value) because the corresponding Synapse code awaits on it. + """ + return ( + """ + + + %s + PGTIOU-84678-8a9d... + + https://proxy2/pgtUrl + https://proxy1/pgtUrl + + + + """ + % cas_user_id + ).encode("utf-8") + + mocked_http_client = Mock(spec=["get_raw"]) + mocked_http_client.get_raw.side_effect = get_raw + + self.hs = self.setup_test_homeserver( + config=config, + proxied_http_client=mocked_http_client, + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.deactivate_account_handler = hs.get_deactivate_account_handler() + + def test_cas_redirect_confirm(self): + """Tests that the SSO login flow serves a confirmation page before redirecting a + user to the redirect URL. + """ + base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" + redirect_url = "https://dodgy-site.com/" + + url_parts = list(urllib.parse.urlparse(base_url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"redirectUrl": redirect_url}) + query.update({"ticket": "ticket"}) + url_parts[4] = urllib.parse.urlencode(query) + cas_ticket_url = urllib.parse.urlunparse(url_parts) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + # Test that the response is HTML. + self.assertEqual(channel.code, 200, channel.result) + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. + self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url""" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + # Because the user is deactivated they are served an error template. + self.assertEqual(channel.code, 403) + self.assertIn(b"SSO account deactivated", channel.result["body"]) + + +@skip_unless(HAS_JWT, "requires jwt") +class JWTTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + jwt_secret = "secret" + jwt_algorithm = "HS256" + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_secret + self.hs.config.jwt_algorithm = self.jwt_algorithm + return self.hs + + def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result + + def jwt_login(self, *args): + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + channel = self.make_request(b"POST", LOGIN_URL, params) + return channel + + def test_login_jwt_valid_registered(self): + self.register_user("kermit", "monkey") + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_valid_unregistered(self): + channel = self.jwt_login({"sub": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, "notsecret") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) + + def test_login_jwt_expired(self): + channel = self.jwt_login({"sub": "frog", "exp": 864000}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Signature has expired" + ) + + def test_login_jwt_not_before(self): + now = int(time.time()) + channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: The token is not yet valid (nbf)", + ) + + def test_login_no_sub(self): + channel = self.jwt_login({"username": "root"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "issuer": "test-issuer", + } + } + ) + def test_login_iss(self): + """Test validating the issuer claim.""" + # A valid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid issuer" + ) + + # Not providing an issuer. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "iss" claim', + ) + + def test_login_iss_no_config(self): + """Test providing an issuer claim without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "audiences": ["test-audience"], + } + } + ) + def test_login_aud(self): + """Test validating the audience claim.""" + # A valid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + # Not providing an audience. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "aud" claim', + ) + + def test_login_aud_no_config(self): + """Test providing an audience without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + def test_login_no_token(self): + params = {"type": "org.matrix.login.jwt"} + channel = self.make_request(b"POST", LOGIN_URL, params) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") + + +# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use +# RSS256, with a public key configured in synapse as "jwt_secret", and tokens +# signed by the private key. +@skip_unless(HAS_JWT, "requires jwt") +class JWTPubKeyTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + ] + + # This key's pubkey is used as the jwt_secret setting of synapse. Valid + # tokens are signed by this and validated using the pubkey. It is generated + # with `openssl genrsa 512` (not a secure way to generate real keys, but + # good enough for tests!) + jwt_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", + "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", + "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", + "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", + "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", + "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", + "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", + "-----END RSA PRIVATE KEY-----", + ] + ) + + # Generated with `openssl rsa -in foo.key -pubout`, with the the above + # private key placed in foo.key (jwt_privatekey). + jwt_pubkey = "\n".join( + [ + "-----BEGIN PUBLIC KEY-----", + "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", + "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", + "-----END PUBLIC KEY-----", + ] + ) + + # This key is used to sign tokens that shouldn't be accepted by synapse. + # Generated just like jwt_privatekey. + bad_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", + "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", + "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", + "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", + "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", + "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", + "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", + "-----END RSA PRIVATE KEY-----", + ] + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_pubkey + self.hs.config.jwt_algorithm = "RS256" + return self.hs + + def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result + + def jwt_login(self, *args): + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + channel = self.make_request(b"POST", LOGIN_URL, params) + return channel + + def test_login_jwt_valid(self): + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) + + +AS_USER = "as_user_alice" + + +class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register.register_servlets, + ] + + def register_as_user(self, username): + self.make_request( + b"POST", + "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), + {"username": username}, + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + + self.service = ApplicationService( + id="unique_identifier", + token="some_token", + hostname="example.com", + sender="@asbot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + self.another_service = ApplicationService( + id="another__identifier", + token="another_token", + hostname="example.com", + sender="@as2bot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as2_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + + self.hs.get_datastore().services_cache.append(self.service) + self.hs.get_datastore().services_cache.append(self.another_service) + return self.hs + + def test_login_appservice_user(self): + """Test that an appservice user can use /login""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_user_bot(self): + """Test that the appservice bot can use /login""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": self.service.sender}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_wrong_user(self): + """Test that non-as users cannot login with the as token""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_wrong_as(self): + """Test that as users cannot login with wrong as token""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.another_service.token + ) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_no_token(self): + """Test that users must provide a token when using the appservice + login method + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + +@skip_unless(HAS_OIDC, "requires OIDC") +class UsernamePickerTestCase(HomeserverTestCase): + """Tests for the username picker flow of SSO login""" + + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + + config["oidc_config"] = {} + config["oidc_config"].update(TEST_OIDC_CONFIG) + config["oidc_config"]["user_mapping_provider"] = { + "config": {"display_name_template": "{{ user.displayname }}"} + } + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://x"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d.update(build_synapse_client_resource_tree(self.hs)) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + + # do the start of the login flow + channel = self.helper.auth_via_oidc( + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + ) + + # that should redirect to the username picker + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + picker_url = location_headers[0] + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") + + # ... with a username_mapping_session cookie + cookies: Dict[str, str] = {} + channel.extract_cookies(cookies) + self.assertIn("username_mapping_session", cookies) + session_id = cookies["username_mapping_session"] + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, + username_mapping_sessions, + "session id not found in map", + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # to the completion page + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=picker_url, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # fish the login token out of the returned redirect uri + login_token = params[2][1] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py new file mode 100644 index 0000000000..3cf5871899 --- /dev/null +++ b/tests/rest/client/test_password_policy.py @@ -0,0 +1,177 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import account, login, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + channel = self.make_request("GET", "/_matrix/client/r0/password_policy") + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_TOO_SHORT, + channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_DIGIT, + channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_SYMBOL, + channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_UPPERCASE, + channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_LOWERCASE, + channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + channel = self.make_request("POST", self.register_url, request_data) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py new file mode 100644 index 0000000000..1d152352d1 --- /dev/null +++ b/tests/rest/client/test_presence.py @@ -0,0 +1,76 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from twisted.internet import defer + +from synapse.handlers.presence import PresenceHandler +from synapse.rest.client import presence +from synapse.types import UserID + +from tests import unittest + + +class PresenceTestCase(unittest.HomeserverTestCase): + """Tests presence REST API.""" + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [presence.register_servlets] + + def make_homeserver(self, reactor, clock): + + presence_handler = Mock(spec=PresenceHandler) + presence_handler.set_state.return_value = defer.succeed(None) + + hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + presence_handler=presence_handler, + ) + + return hs + + def test_put_presence(self): + """ + PUT to the status endpoint with use_presence enabled will call + set_state on the presence handler. + """ + self.hs.config.use_presence = True + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + + @unittest.override_config({"use_presence": False}) + def test_put_presence_disabled(self): + """ + PUT to the status endpoint with use_presence disabled will NOT call + set_state on the presence handler. + """ + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py new file mode 100644 index 0000000000..2860579c2e --- /dev/null +++ b/tests/rest/client/test_profile.py @@ -0,0 +1,270 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /profile paths.""" +from synapse.rest import admin +from synapse.rest.client import login, profile, room + +from tests import unittest + + +class ProfileTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def prepare(self, reactor, clock, hs): + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.other = self.register_user("other", "pass", displayname="Bob") + + def test_get_displayname(self): + res = self._get_displayname() + self.assertEqual(res, "owner") + + def test_set_displayname(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_displayname() + self.assertEqual(res, "test") + + def test_set_displayname_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_displayname_too_long(self): + """Attempts to set a stupid displayname should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_displayname() + self.assertEqual(res, "owner") + + def test_get_displayname_other(self): + res = self._get_displayname(self.other) + self.assertEquals(res, "Bob") + + def test_set_displayname_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.other,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_get_avatar_url(self): + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_set_avatar_url(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_avatar_url() + self.assertEqual(res, "http://my.server/pic.gif") + + def test_set_avatar_url_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_avatar_url_too_long(self): + """Attempts to set a stupid avatar_url should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_get_avatar_url_other(self): + res = self._get_avatar_url(self.other) + self.assertIsNone(res) + + def test_set_avatar_url_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.other,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def _get_displayname(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/displayname" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["displayname"] + + def _get_avatar_url(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/avatar_url" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body.get("avatar_url") + + +class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User owning the requested profile. + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.profile_url = "/profile/%s" % (self.owner) + + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) + + def test_no_auth(self): + self.try_fetch_profile(401) + + def test_not_in_shared_room(self): + self.ensure_requester_left_room() + + self.try_fetch_profile(403, access_token=self.requester_tok) + + def test_in_shared_room(self): + self.ensure_requester_left_room() + + self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) + + self.try_fetch_profile(200, self.requester_tok) + + def try_fetch_profile(self, expected_code, access_token=None): + self.request_profile(expected_code, access_token=access_token) + + self.request_profile( + expected_code, url_suffix="/displayname", access_token=access_token + ) + + self.request_profile( + expected_code, url_suffix="/avatar_url", access_token=access_token + ) + + def request_profile(self, expected_code, url_suffix="", access_token=None): + channel = self.make_request( + "GET", self.profile_url + url_suffix, access_token=access_token + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def ensure_requester_left_room(self): + try: + self.helper.leave( + room=self.room_id, user=self.requester, tok=self.requester_tok + ) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass + + +class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + def test_can_lookup_own_profile(self): + """Tests that a user can lookup their own profile without having to be in a room + if 'require_auth_for_profile_requests' is set to true in the server's config. + """ + channel = self.make_request( + "GET", "/profile/" + self.requester, access_token=self.requester_tok + ) + self.assertEqual(channel.code, 200, channel.result) + + channel = self.make_request( + "GET", + "/profile/" + self.requester + "/displayname", + access_token=self.requester_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + channel = self.make_request( + "GET", + "/profile/" + self.requester + "/avatar_url", + access_token=self.requester_tok, + ) + self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py new file mode 100644 index 0000000000..d0ce91ccd9 --- /dev/null +++ b/tests/rest/client/test_push_rule_attrs.py @@ -0,0 +1,414 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse +from synapse.api.errors import Codes +from synapse.rest.client import login, push_rule, room + +from tests.unittest import HomeserverTestCase + + +class PushRuleAttributesTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + push_rule.register_servlets, + ] + hijack_auth = False + + def test_enabled_on_creation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even though a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_on_recreation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even if a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # disable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule disabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_disable(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is disabled and enabled when we ask for it. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # disable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule disabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # re-enable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule enabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_404_when_get_non_existent(self): + """ + Tests that `enabled` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_get_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_rule(self): + """ + Tests that `enabled` gives 404 when we put to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_get(self): + """ + Tests that `actions` gives you what you expect on a fresh rule. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] + ) + + def test_actions_put(self): + """ + Tests that PUT on actions updates the value you'd get from GET. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # change the rule actions + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["actions"], ["dont_notify"]) + + def test_actions_404_when_get_non_existent(self): + """ + Tests that `actions` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_get_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_rule(self): + """ + Tests that `actions` gives 404 when putting to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py new file mode 100644 index 0000000000..fecda037a5 --- /dev/null +++ b/tests/rest/client/test_register.py @@ -0,0 +1,746 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import json +import os + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.errors import Codes +from synapse.appservice import ApplicationService +from synapse.rest.client import account, account_validity, login, logout, register, sync + +from tests import unittest +from tests.unittest import override_config + + +class RegisterRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] + url = b"/_matrix/client/r0/register" + + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config + + def test_POST_appservice_registration_valid(self): + user_id = "@as_user_kermit:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps( + {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + det_data = {"user_id": user_id, "home_server": self.hs.hostname} + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_appservice_registration_no_type(self): + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps({"username": "as_user_kermit"}) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"400", channel.result) + + def test_POST_appservice_registration_invalid(self): + self.appservice = None # no application service exists + request_data = json.dumps( + {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + def test_POST_bad_password(self): + request_data = json.dumps({"username": "kermit", "password": 666}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid password") + + def test_POST_bad_username(self): + request_data = json.dumps({"username": 777, "password": "monkey"}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid username") + + def test_POST_user_valid(self): + user_id = "@kermit:test" + device_id = "frogfone" + params = { + "username": "kermit", + "password": "monkey", + "device_id": device_id, + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + det_data = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + @override_config({"enable_registration": False}) + def test_POST_disabled_registration(self): + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) + + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Registration has been disabled") + self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + + def test_POST_guest_registration(self): + self.hs.config.macaroon_secret_key = "test" + self.hs.config.allow_guest_access = True + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_disabled_guest_registration(self): + self.hs.config.allow_guest_access = False + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Guest access is disabled") + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting_guest(self): + for i in range(0, 6): + url = self.url + b"?kind=guest" + channel = self.make_request(b"POST", url, b"{}") + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting(self): + for i in range(0, 6): + params = { + "username": "kermit" + str(i), + "password": "monkey", + "device_id": "frogfone", + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_advertised_flows(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we only expect the dummy flow + self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "enable_registration_captcha": True, + "user_consent": { + "version": "1", + "template_dir": "/", + "require_at_registration": True, + }, + "account_threepid_delegates": { + "email": "https://id_server", + "msisdn": "https://id_server", + }, + } + ) + def test_advertised_flows_captcha_and_terms_and_3pids(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + self.assertCountEqual( + [ + ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], + ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], + ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], + [ + "m.login.recaptcha", + "m.login.terms", + "m.login.msisdn", + "m.login.email.identity", + ], + ], + (f["stages"] for f in flows), + ) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "registrations_require_3pid": ["email"], + "disable_msisdn_registration": True, + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_advertised_flows_no_msisdn_email_required(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we expect all four combinations of 3pid + self.assertCountEqual( + [["m.login.email.identity"]], (f["stages"] for f in flows) + ) + + @unittest.override_config( + { + "request_token_inhibit_3pid_errors": True, + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_request_token_existing_email_inhibit_error(self): + """Test that requesting a token via this endpoint doesn't leak existing + associations if configured that way. + """ + user_id = self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.hs.get_datastore().user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(200, channel.code, channel.result) + + self.assertIsNotNone(channel.json_body.get("sid")) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_reject_invalid_email(self): + """Check that bad emails are rejected""" + + # Test for email with multiple @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + # Check error to ensure that we're not erroring due to a bug in the test. + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for email with no @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for super long email + email = "a@" + "a" * 1000 + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + +class AccountValidityTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + logout.register_servlets, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # Test for account expiring after a week. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_validity_period(self): + self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_manual_renewal(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + # If we register the admin user at the beginning of the test, it will + # expire at the same time as the normal user and the renewal request + # will be denied. + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = {"user_id": user_id} + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_manual_expire(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + channel = self.make_request(b"POST", "/logout", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + account_validity.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Test for account expiring after a week and renewal emails being sent 2 + # days before expiry. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + "renew_at": 172800000, # Time in ms for 2 days + "renew_by_email_enabled": True, + "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", + } + + # Email config. + + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "aaa" + + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + + self.store = self.hs.get_datastore() + + return self.hs + + def test_renewal_email(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) + self.assertEqual(len(self.email_attempts), 1) + + # Retrieving the URL from the email is too much pain for now, so we + # retrieve the token from the DB. + renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 3 days forward. If the renewal failed, every authed request with + # our access token should be denied from now, otherwise they should + # succeed. + self.reactor.advance(datetime.timedelta(days=3).total_seconds()) + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. + url = "/_matrix/client/unstable/account_validity/renew?token=123" + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = ( + self.hs.config.account_validity.account_validity_invalid_token_template.render() + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + def test_manual_email_send(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + def test_deactivated_user(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } + ) + channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(channel.code, 200) + + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + self.assertEqual(len(self.email_attempts), 0) + + def create_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + return user_id, tok + + def test_manual_email_send_expired_account(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + + # Make the account expire. + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + # Ignore all emails sent by the automatic background task and only focus on the + # ones sent manually. + self.email_attempts = [] + + # Test that we're still able to manually trigger a mail to be sent. + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + +class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + self.validity_period = 10 + self.max_delta = self.validity_period * 10.0 / 100.0 + + config = self.default_config() + + config["enable_registration"] = True + config["account_validity"] = {"enabled": False} + + self.hs = self.setup_test_homeserver(config=config) + + # We need to set these directly, instead of in the homeserver config dict above. + # This is due to account validity-related config options not being read by + # Synapse when account_validity.enabled is False. + self.hs.get_datastore()._account_validity_period = self.validity_period + self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + + self.store = self.hs.get_datastore() + + return self.hs + + def test_background_job(self): + """ + Tests the same thing as test_background_job, except that it sets the + startup_job_max_delta parameter and checks that the expiration date is within the + allowed range. + """ + user_id = self.register_user("kermit_delta", "user") + + self.hs.config.account_validity.startup_job_max_delta = self.max_delta + + now_ms = self.hs.get_clock().time_msec() + self.get_success(self.store._set_expiration_date_when_missing()) + + res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + + self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) + self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py new file mode 100644 index 0000000000..02b5e9a8d0 --- /dev/null +++ b/tests/rest/client/test_relations.py @@ -0,0 +1,724 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json +import urllib +from typing import Optional + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.rest import admin +from synapse.rest.client import login, register, relations, room + +from tests import unittest + + +class RelationsTestCase(unittest.HomeserverTestCase): + servlets = [ + relations.register_servlets, + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets_for_client_rest_resource, + ] + hijack_auth = False + + def make_homeserver(self, reactor, clock): + # We need to enable msc1849 support for aggregations + config = self.default_config() + config["experimental_msc1849_support_enabled"] = True + + # We enable frozen dicts as relations/edits change event contents, so we + # want to test that we don't modify the events in the caches. + config["use_frozen_dicts"] = True + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.user_id, self.user_token = self._create_user("alice") + self.user2_id, self.user2_token = self._create_user("bob") + + self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) + self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) + res = self.helper.send(self.room, body="Hi!", tok=self.user_token) + self.parent_id = res["event_id"] + + def test_send_relation(self): + """Tests that sending a relation using the new /send_relation works + creates the right shape of event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + self.assertEquals(200, channel.code, channel.json_body) + + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assert_dict( + { + "type": "m.reaction", + "sender": self.user_id, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "key": "👍", + "rel_type": RelationTypes.ANNOTATION, + } + }, + }, + channel.json_body, + ) + + def test_deny_membership(self): + """Test that we deny relations on membership events""" + channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) + self.assertEquals(400, channel.code, channel.json_body) + + def test_deny_double_react(self): + """Test that we deny relations on membership events""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(400, channel.code, channel.json_body) + + def test_basic_paginate_relations(self): + """Tests that calling pagination API correctly the latest relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEquals( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), str, channel.json_body + ) + + def test_repeated_paginate_relations(self): + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = None + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation_pagination_groups(self): + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + + idx += 1 + idx %= len(access_tokens) + + prev_token = None + found_groups = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEquals(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self): + """Test that we can paginate within an annotation group.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="👍", + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + idx += 1 + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_event_ids = [] + encoded_key = urllib.parse.quote_plus("👍".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s" + "/aggregations/%s/%s/m.reaction/%s?limit=1%s" + % ( + self.room, + self.parent_id, + RelationTypes.ANNOTATION, + encoded_key, + from_token, + ), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation(self): + """Test that annotations get correctly aggregated.""" + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + ) + + def test_aggregation_redactions(self): + """Test that annotations get correctly aggregated after a redaction.""" + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Now lets redact one of the 'a' reactions + channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), + access_token=self.user_token, + content={}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + def test_aggregation_must_be_annotation(self): + """Test that aggregations must be annotations.""" + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" + % (self.room, self.parent_id, RelationTypes.REPLACE), + access_token=self.user_token, + ) + self.assertEquals(400, channel.code, channel.json_body) + + def test_aggregation_get_event(self): + """Test that annotations and references get correctly bundled when + getting the parent event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_2 = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + }, + ) + + def test_edit(self): + """Test that a simple edit works.""" + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_multi_edit(self): + """Test that multiple edits, including attempts by people who + shouldn't be allowed, are correctly handled. + """ + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_edit_reply(self): + """Test that editing a reply works.""" + + # Create a reply to edit. + channel = self._send_relation( + RelationTypes.REFERENCE, + "m.room.message", + content={"msgtype": "m.text", "body": "A reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + reply = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=reply, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, reply), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to see the new body in the dict, as well as the reference + # metadata sill intact. + self.assertDictContainsSubset(new_body, channel.json_body["content"]) + self.assertDictContainsSubset( + { + "m.relates_to": { + "event_id": self.parent_id, + "key": None, + "rel_type": "m.reference", + } + }, + channel.json_body["content"], + ) + + # We expect that the edit relation appears in the unsigned relations + # section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_relations_redaction_redacts_edits(self): + """Test that edits of an event are redacted when the original event + is redacted. + """ + # Send a new event + res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + parent_id=original_event_id, + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check the relation is returned + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(len(channel.json_body["chunk"]), 1) + + # Redact the original event + channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), + access_token=self.user_token, + content="{}", + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Try to check for remaining m.replace relations + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that no relations are returned + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + + def test_aggregations_redaction_prevents_access_to_aggregations(self): + """Test that annotations of an event are redacted when the original event + is redacted. + """ + # Send a new event + res = self.helper.send(self.room, body="Hello!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Redact the original + channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % ( + self.room, + original_event_id, + "test_aggregations_redaction_prevents_access_to_aggregations", + ), + access_token=self.user_token, + content="{}", + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that aggregations returns zero + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + + def _send_relation( + self, + relation_type, + event_type, + key=None, + content: Optional[dict] = None, + access_token=None, + parent_id=None, + ): + """Helper function to send a relation pointing at `self.parent_id` + + Args: + relation_type (str): One of `RelationTypes` + event_type (str): The type of the event to create + parent_id (str): The event_id this relation relates to. If None, then self.parent_id + key (str|None): The aggregation key used for m.annotation relation + type. + content(dict|None): The content of the created event. + access_token (str|None): The access token used to send the relation, + defaults to `self.user_token` + + Returns: + FakeChannel + """ + if not access_token: + access_token = self.user_token + + query = "" + if key: + query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) + + original_id = parent_id if parent_id else self.parent_id + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" + % (self.room, original_id, relation_type, event_type, query), + json.dumps(content or {}).encode("utf-8"), + access_token=access_token, + ) + return channel + + def _create_user(self, localpart): + user_id = self.register_user(localpart, "abc123") + access_token = self.login(localpart, "abc123") + + return user_id, access_token diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py new file mode 100644 index 0000000000..ee6b0b9ebf --- /dev/null +++ b/tests/rest/client/test_report_event.py @@ -0,0 +1,82 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.rest.client import login, report_event, room + +from tests import unittest + + +class ReportEventTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) + resp = self.helper.send(self.room_id, tok=self.admin_user_tok) + self.event_id = resp["event_id"] + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" + + def test_reason_str_and_score_int(self): + data = {"reason": "this makes me sad", "score": -100} + self._assert_status(200, data) + + def test_no_reason(self): + data = {"score": 0} + self._assert_status(200, data) + + def test_no_score(self): + data = {"reason": "this makes me sad"} + self._assert_status(200, data) + + def test_no_reason_and_no_score(self): + data = {} + self._assert_status(200, data) + + def test_reason_int_and_score_str(self): + data = {"reason": 10, "score": "string"} + self._assert_status(400, data) + + def test_reason_zero_and_score_blank(self): + data = {"reason": 0, "score": ""} + self._assert_status(400, data) + + def test_reason_and_score_null(self): + data = {"reason": None, "score": None} + self._assert_status(400, data) + + def _assert_status(self, response_status, data): + channel = self.make_request( + "POST", + self.report_path, + json.dumps(data), + access_token=self.other_user_tok, + ) + self.assertEqual( + response_status, int(channel.result["code"]), msg=channel.result["body"] + ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py new file mode 100644 index 0000000000..0c9cbb9aff --- /dev/null +++ b/tests/rest/client/test_rooms.py @@ -0,0 +1,2150 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +import json +from typing import Iterable +from unittest.mock import Mock, call +from urllib import parse as urlparse + +from twisted.internet import defer + +import synapse.rest.admin +from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException +from synapse.handlers.pagination import PurgeStatus +from synapse.rest import admin +from synapse.rest.client import account, directory, login, profile, room +from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.test_utils import make_awaitable + +PATH_PREFIX = b"/_matrix/client/api/v1" + + +class RoomBase(unittest.HomeserverTestCase): + rmcreator_id = None + + servlets = [room.register_servlets, room.register_deprecated_servlets] + + def make_homeserver(self, reactor, clock): + + self.hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + ) + + self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler.return_value.maybe_backfill = Mock( + return_value=make_awaitable(None) + ) + + async def _insert_client_ip(*args, **kwargs): + return None + + self.hs.get_datastore().insert_client_ip = _insert_client_ip + + return self.hs + + +class RoomPermissionsTestCase(RoomBase): + """Tests room permissions.""" + + user_id = "@sid1:red" + rmcreator_id = "@notme:red" + + def prepare(self, reactor, clock, hs): + + self.helper.auth_user_id = self.rmcreator_id + # create some rooms under the name rmcreator_id + self.uncreated_rmid = "!aa:test" + self.created_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=False + ) + self.created_public_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=True + ) + + # send a message in one of the rooms + self.created_rmid_msg_path = ( + "rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ).encode("ascii") + channel = self.make_request( + "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' + ) + self.assertEquals(200, channel.code, channel.result) + + # set topic for public room + channel = self.make_request( + "PUT", + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), + b'{"topic":"Public Room Topic"}', + ) + self.assertEquals(200, channel.code, channel.result) + + # auth as user_id now + self.helper.auth_user_id = self.user_id + + def test_can_do_action(self): + msg_content = b'{"msgtype":"m.text","body":"hello"}' + + seq = iter(range(100)) + + def send_msg_path(): + return "/rooms/%s/send/m.room.message/mid%s" % ( + self.created_rmid, + str(next(seq)), + ) + + # send message in uncreated room, expect 403 + channel = self.make_request( + "PUT", + "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + msg_content, + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room not joined (no state), expect 403 + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room and invited, expect 403 + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room and joined, expect 200 + self.helper.join(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # send message in created room and left, expect 403 + self.helper.leave(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_topic_perms(self): + topic_content = b'{"topic":"My Topic Name"}' + topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid + + # set/get topic in uncreated room, expect 403 + channel = self.make_request( + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set/get topic in created PRIVATE room not joined, expect 403 + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request("GET", topic_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set topic in created PRIVATE room and invited, expect 403 + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # get topic in created PRIVATE room and invited, expect 403 + channel = self.make_request("GET", topic_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set/get topic in created PRIVATE room and joined, expect 200 + self.helper.join(room=self.created_rmid, user=self.user_id) + + # Only room ops can set topic by default + self.helper.auth_user_id = self.rmcreator_id + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.helper.auth_user_id = self.user_id + + channel = self.make_request("GET", topic_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) + + # set/get topic in created PRIVATE room and left, expect 403 + self.helper.leave(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request("GET", topic_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # get topic in PUBLIC room, not joined, expect 403 + channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set topic in PUBLIC room, not joined, expect 403 + channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content, + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def _test_get_membership( + self, room=None, members: Iterable = frozenset(), expect_code=None + ): + for member in members: + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + channel = self.make_request("GET", path) + self.assertEquals(expect_code, channel.code) + + def test_membership_basic_room_perms(self): + # === room does not exist === + room = self.uncreated_rmid + # get membership of self, get membership of other, uncreated room + # expect all 403s + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # trying to invite people to this room should 403 + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + + # set [invite/join/left] of self, set [invite/join/left] of other, + # expect all 404s because room doesn't exist on any server + for usr in [self.user_id, self.rmcreator_id]: + self.helper.join(room=room, user=usr, expect_code=404) + self.helper.leave(room=room, user=usr, expect_code=404) + + def test_membership_private_room_perms(self): + room = self.created_rmid + # get membership of self, get membership of other, private room + invite + # expect all 403s + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # get membership of self, get membership of other, private room + joined + # expect all 200s + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + # get membership of self, get membership of other, private room + left + # expect all 200s + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + def test_membership_public_room_perms(self): + room = self.created_public_rmid + # get membership of self, get membership of other, public room + invite + # expect 403 + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # get membership of self, get membership of other, public room + joined + # expect all 200s + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + # get membership of self, get membership of other, public room + left + # expect 200. + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + def test_invited_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + + # set [invite/join/left] of other user, expect 403s + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.JOIN, + expect_code=403, + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + + def test_joined_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + + # set invited of self, expect 403 + self.helper.invite( + room=room, src=self.user_id, targ=self.user_id, expect_code=403 + ) + + # set joined of self, expect 200 (NOOP) + self.helper.join(room=room, user=self.user_id) + + other = "@burgundy:red" + # set invited of other, expect 200 + self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) + + # set joined of other, expect 403 + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.JOIN, + expect_code=403, + ) + + # set left of other, expect 403 + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.LEAVE, + expect_code=403, + ) + + # set left of self, expect 200 + self.helper.leave(room=room, user=self.user_id) + + def test_leave_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) + + # set [invite/join/left] of self, set [invite/join/left] of other, + # expect all 403s + for usr in [self.user_id, self.rmcreator_id]: + self.helper.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.INVITE, + expect_code=403, + ) + + self.helper.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.JOIN, + expect_code=403, + ) + + # It is always valid to LEAVE if you've already left (currently.) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + + +class RoomsMemberListTestCase(RoomBase): + """Tests /rooms/$room_id/members/list REST events.""" + + user_id = "@sid1:red" + + def test_get_member_list(self): + room_id = self.helper.create_room_as(self.user_id) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_room(self): + channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_permission(self): + room_id = self.helper.create_room_as("@some_other_guy:red") + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_mixed_memberships(self): + room_creator = "@some_other_guy:red" + room_id = self.helper.create_room_as(room_creator) + room_path = "/rooms/%s/members" % room_id + self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) + # can't see list if you're just invited. + channel = self.make_request("GET", room_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + self.helper.join(room=room_id, user=self.user_id) + # can see list now joined + channel = self.make_request("GET", room_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + self.helper.leave(room=room_id, user=self.user_id) + # can see old list once left + channel = self.make_request("GET", room_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + +class RoomsCreateTestCase(RoomBase): + """Tests /rooms and /rooms/$room_id REST events.""" + + user_id = "@sid1:red" + + def test_post_room_no_keys(self): + # POST with no config keys, expect new room id + channel = self.make_request("POST", "/createRoom", "{}") + + self.assertEquals(200, channel.code, channel.result) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_visibility_key(self): + # POST with visibility config key, expect new room id + channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_custom_key(self): + # POST with custom config keys, expect new room id + channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_known_and_unknown_keys(self): + # POST with custom + known config keys, expect new room id + channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' + ) + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_invalid_content(self): + # POST with invalid content / paths, expect 400 + channel = self.make_request("POST", "/createRoom", b'{"visibili') + self.assertEquals(400, channel.code) + + channel = self.make_request("POST", "/createRoom", b'["hello"]') + self.assertEquals(400, channel.code) + + def test_post_room_invitees_invalid_mxid(self): + # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 + # Note the trailing space in the MXID here! + channel = self.make_request( + "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' + ) + self.assertEquals(400, channel.code) + + @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) + def test_post_room_invitees_ratelimit(self): + """Test that invites sent when creating a room are ratelimited by a RateLimiter, + which ratelimits them correctly, including by not limiting when the requester is + exempt from ratelimiting. + """ + + # Build the request's content. We use local MXIDs because invites over federation + # are more difficult to mock. + content = json.dumps( + { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } + ).encode("utf8") + + # Test that the invites are correctly ratelimited. + channel = self.make_request("POST", "/createRoom", content) + self.assertEqual(400, channel.code) + self.assertEqual( + "Cannot invite so many users at once", + channel.json_body["error"], + ) + + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + ) + + # Test that the invites aren't ratelimited anymore. + channel = self.make_request("POST", "/createRoom", content) + self.assertEqual(200, channel.code) + + +class RoomTopicTestCase(RoomBase): + """Tests /rooms/$room_id/topic REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + # create the room + self.room_id = self.helper.create_room_as(self.user_id) + self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) + + def test_invalid_puts(self): + # missing keys or invalid json + channel = self.make_request("PUT", self.path, "{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, '{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request( + "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' + ) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, "text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, "") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # valid key, wrong type + content = '{"topic":["Topic name"]}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_topic(self): + # nothing should be there + channel = self.make_request("GET", self.path) + self.assertEquals(404, channel.code, msg=channel.result["body"]) + + # valid put + content = '{"topic":"Topic name"}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # valid get + channel = self.make_request("GET", self.path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) + + def test_rooms_topic_with_extra_keys(self): + # valid put with extra keys + content = '{"topic":"Seasons","subtopic":"Summer"}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # valid get + channel = self.make_request("GET", self.path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) + + +class RoomMemberStateTestCase(RoomBase): + """Tests /rooms/$room_id/members/$user_id/state REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_invalid_puts(self): + path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) + # missing keys or invalid json + channel = self.make_request("PUT", path, "{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, '{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, '{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, "text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, "") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # valid keys, wrong types + content = '{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + ) + channel = self.make_request("PUT", path, content.encode("ascii")) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_members_self(self): + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.user_id, + ) + + # valid join message (NOOP since we made the room) + content = '{"membership":"%s"}' % Membership.JOIN + channel = self.make_request("PUT", path, content.encode("ascii")) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + expected_response = {"membership": Membership.JOIN} + self.assertEquals(expected_response, channel.json_body) + + def test_rooms_members_other(self): + self.other_id = "@zzsid1:red" + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.other_id, + ) + + # valid invite message + content = '{"membership":"%s"}' % Membership.INVITE + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) + + def test_rooms_members_other_custom_keys(self): + self.other_id = "@zzsid1:red" + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.other_id, + ) + + # valid invite message with custom key + content = '{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, + "Join us!", + ) + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) + + +class RoomInviteRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_rooms_ratelimit(self): + """Tests that invites in a room are actually rate-limited.""" + room_id = self.helper.create_room_as(self.user_id) + + for i in range(3): + self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) + + self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_users_ratelimit(self): + """Tests that invites to a specific user are actually rate-limited.""" + + for _ in range(3): + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red") + + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) + + +class RoomJoinRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit(self): + """Tests that local joins are actually rate-limited.""" + for _ in range(3): + self.helper.create_room_as(self.user_id) + + self.helper.create_room_as(self.user_id, expect_code=429) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_profile_change(self): + """Tests that sending a profile update into all of the user's joined rooms isn't + rate-limited by the rate-limiter on joins.""" + + # Create and join as many rooms as the rate-limiting config allows in a second. + room_ids = [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + # Let some time for the rate-limiter to forget about our multi-join. + self.reactor.advance(2) + # Add one to make sure we're joined to more rooms than the config allows us to + # join in a second. + room_ids.append(self.helper.create_room_as(self.user_id)) + + # Create a profile for the user, since it hasn't been done on registration. + store = self.hs.get_datastore() + self.get_success( + store.create_profile(UserID.from_string(self.user_id).localpart) + ) + + # Update the display name for the user. + path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id + channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + self.assertEquals(channel.code, 200, channel.json_body) + + # Check that all the rooms have been sent a profile update into. + for room_id in room_ids: + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + room_id, + self.user_id, + ) + + channel = self.make_request("GET", path) + self.assertEquals(channel.code, 200) + + self.assertIn("displayname", channel.json_body) + self.assertEquals(channel.json_body["displayname"], "John Doe") + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_idempotent(self): + """Tests that the room join endpoints remain idempotent despite rate-limiting + on room joins.""" + room_id = self.helper.create_room_as(self.user_id) + + # Let's test both paths to be sure. + paths_to_test = [ + "/_matrix/client/r0/rooms/%s/join", + "/_matrix/client/r0/join/%s", + ] + + for path in paths_to_test: + # Make sure we send more requests than the rate-limiting config would allow + # if all of these requests ended up joining the user to a room. + for _ in range(4): + channel = self.make_request("POST", path % room_id, {}) + self.assertEquals(channel.code, 200) + + @unittest.override_config( + { + "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, + "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], + "autocreate_auto_join_rooms": True, + }, + ) + def test_autojoin_rooms(self): + user_id = self.register_user("testuser", "password") + + # Check that the new user successfully joined the four rooms + rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 4) + + +class RoomMessagesTestCase(RoomBase): + """Tests /rooms/$room_id/messages/$user_id/$msg_id REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_invalid_puts(self): + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) + # missing keys or invalid json + channel = self.make_request("PUT", path, b"{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b"text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b"") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_messages_sent(self): + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) + + content = b'{"body":"test","msgtype":{"type":"a"}}' + channel = self.make_request("PUT", path, content) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # custom message types + content = b'{"body":"test","msgtype":"test.custom.text"}' + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # m.text message type + path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) + content = b'{"body":"test2","msgtype":"m.text"}' + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + +class RoomInitialSyncTestCase(RoomBase): + """Tests /rooms/$room_id/initialSync.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + # create the room + self.room_id = self.helper.create_room_as(self.user_id) + + def test_initial_sync(self): + channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) + self.assertEquals(200, channel.code) + + self.assertEquals(self.room_id, channel.json_body["room_id"]) + self.assertEquals("join", channel.json_body["membership"]) + + # Room state is easier to assert on if we unpack it into a dict + state = {} + for event in channel.json_body["state"]: + if "state_key" not in event: + continue + t = event["type"] + if t not in state: + state[t] = [] + state[t].append(event) + + self.assertTrue("m.room.create" in state) + + self.assertTrue("messages" in channel.json_body) + self.assertTrue("chunk" in channel.json_body["messages"]) + self.assertTrue("end" in channel.json_body["messages"]) + + self.assertTrue("presence" in channel.json_body) + + presence_by_user = { + e["content"]["user_id"]: e for e in channel.json_body["presence"] + } + self.assertTrue(self.user_id in presence_by_user) + self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) + + +class RoomMessageListTestCase(RoomBase): + """Tests /rooms/$room_id/messages REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_topo_token_is_accepted(self): + token = "t1-0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + self.assertEquals(200, channel.code) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body["start"]) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) + + def test_stream_token_is_accepted_for_fwd_pagianation(self): + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + self.assertEquals(200, channel.code) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body["start"]) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) + + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + first_token_str = self.get_success(first_token.to_string(store)) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + second_token_str = self.get_success(second_token.to_string(store)) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token_str, + delete_local_events=True, + ) + ) + + # Check that we only get the second message through /message now that the first + # has been purged. + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + first_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + + +class RoomSearchTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def prepare(self, reactor, clock, hs): + + # Register the user who does the searching + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register the user who sends the message + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + # Invite the other person + self.helper.invite( + room=self.room, + src=self.user_id, + tok=self.access_token, + targ=self.other_user_id, + ) + + # The other user joins + self.helper.join( + room=self.room, user=self.other_user_id, tok=self.other_access_token + ) + + def test_finds_message(self): + """ + The search functionality will search for content in messages if asked to + do so. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": {"keys": ["content.body"], "search_term": "Hi"} + } + }, + ) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # No context was requested, so we should get none. + self.assertEqual(results["results"][0]["context"], {}) + + def test_include_context(self): + """ + When event_context includes include_profile, profile information will be + included in the search response. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": { + "keys": ["content.body"], + "search_term": "Hi", + "event_context": {"include_profile": True}, + } + } + }, + ) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # We should get context info, like the two users, and the display names. + context = results["results"][0]["context"] + self.assertEqual(len(context["profile_info"].keys()), 2) + self.assertEqual( + context["profile_info"][self.other_user_id]["displayname"], "otheruser" + ) + + +class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/publicRooms" + + config = self.default_config() + config["allow_public_rooms_without_auth"] = False + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_restricted_no_auth(self): + channel = self.make_request("GET", self.url) + self.assertEqual(channel.code, 401, channel.result) + + def test_restricted_auth(self): + self.register_user("user", "pass") + tok = self.login("user", "pass") + + channel = self.make_request("GET", self.url, access_token=tok) + self.assertEqual(channel.code, 200, channel.result) + + +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fallback to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. + self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + +class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["allow_per_room_profiles"] = False + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + + # Set a profile for the test user + self.displayname = "test user" + data = {"displayname": self.displayname} + request_data = json.dumps(data) + channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), + request_data, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_per_room_profile_forbidden(self): + data = {"membership": "join", "displayname": "other test user"} + request_data = json.dumps(data) + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (self.room_id, self.user_id), + request_data, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res_displayname = channel.json_body["content"]["displayname"] + self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/join", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/leave", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/kick", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/ban", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/unban", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/invite", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/leave", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_context_filter_labels(self): + """Test that we can filter by a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = channel.json_body["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ + event_id = self._send_labelled_messages_in_room() + + channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 2, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 4, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 1, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + Returns: + The ID of the event to use if we're testing filtering on /context. + """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id + + +class ContextTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + account.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.room_id = self.helper.create_room_as( + self.user_id, tok=self.tok, is_public=False + ) + + self.other_user_id = self.register_user("user2", "password") + self.other_tok = self.login("user2", "password") + + self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) + + def test_erased_sender(self): + """Test that an erasure request results in the requester's events being hidden + from any new member of the room. + """ + + # Send a bunch of events in the room. + + self.helper.send(self.room_id, "message 1", tok=self.tok) + self.helper.send(self.room_id, "message 2", tok=self.tok) + event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] + self.helper.send(self.room_id, "message 4", tok=self.tok) + self.helper.send(self.room_id, "message 5", tok=self.tok) + + # Check that we can still see the messages before the erasure request. + + channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertEqual( + events_before[0].get("content", {}).get("body"), + "message 2", + events_before[0], + ) + self.assertEqual( + events_before[1].get("content", {}).get("body"), + "message 1", + events_before[1], + ) + + self.assertEqual( + channel.json_body["event"].get("content", {}).get("body"), + "message 3", + channel.json_body["event"], + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertEqual( + events_after[0].get("content", {}).get("body"), + "message 4", + events_after[0], + ) + self.assertEqual( + events_after[1].get("content", {}).get("body"), + "message 5", + events_after[1], + ) + + # Deactivate the first account and erase the user's data. + + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) + ) + + # Invite another user in the room. This is needed because messages will be + # pruned only if the user wasn't a member of the room when the messages were + # sent. + + invited_user_id = self.register_user("user3", "password") + invited_tok = self.login("user3", "password") + + self.helper.invite( + self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok + ) + self.helper.join(self.room_id, invited_user_id, tok=invited_tok) + + # Check that a user that joined the room after the erasure request can't see + # the messages anymore. + + channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=invited_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) + self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) + + self.assertDictEqual( + channel.json_body["event"].get("content"), {}, channel.json_body["event"] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) + self.assertEqual(events_after[1].get("content"), {}, events_after[1]) + + +class RoomAliasListTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + def test_no_aliases(self): + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(res["aliases"], []) + + def test_not_in_room(self): + self.register_user("user", "test") + user_tok = self.login("user", "test") + res = self._get_aliases(user_tok, expected_code=403) + self.assertEqual(res["errcode"], "M_FORBIDDEN") + + def test_admin_user(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.register_user("user", "test", admin=True) + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def test_with_aliases(self): + alias1 = self._random_alias() + alias2 = self._random_alias() + + self._set_alias_via_directory(alias1) + self._set_alias_via_directory(alias2) + + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(set(res["aliases"]), {alias1, alias2}) + + def test_peekable_room(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "world_readable"}, + tok=self.room_owner_tok, + ) + + self.register_user("user", "test") + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,), + access_token=access_token, + ) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + if expected_code == 200: + self.assertIsInstance(res["aliases"], list) + return res + + def _random_alias(self) -> str: + return RoomAlias(random_string(5), self.hs.hostname).to_string() + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py new file mode 100644 index 0000000000..6db7062a8e --- /dev/null +++ b/tests/rest/client/test_sendtodevice.py @@ -0,0 +1,200 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client import login, sendtodevice, sync + +from tests.unittest import HomeserverTestCase, override_config + + +class SendToDeviceTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + sendtodevice.register_servlets, + sync.register_servlets, + ] + + def test_user_to_user(self): + """A to-device message from one user to another should get delivered""" + + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send the message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # check it appears + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + expected_result = { + "events": [ + { + "sender": user1, + "type": "m.test", + "content": test_msg, + } + ] + } + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should re-appear if we do another sync + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should *not* appear if we do an incremental sync + sync_token = channel.json_body["next_batch"] + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_local_room_key_request(self): + """m.room_key_request has special-casing; test from local user""" + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send three messages + for i in range(3): + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", + content={"messages": {user2: {"d2": {"idx": i}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # now sync: we should get two of the three + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={"messages": {user2: {"d2": {"idx": 3}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # ... which should arrive + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, + ) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_remote_room_key_request(self): + """m.room_key_request has special-casing; test from remote user""" + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + federation_registry = self.hs.get_federation_registry() + + # send three messages + for i in range(3): + self.get_success( + federation_registry.on_edu( + "m.direct_to_device", + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": i}}}, + "message_id": f"{i}", + }, + ) + ) + + # now sync: we should get two of the three + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": i}, + }, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + self.get_success( + federation_registry.on_edu( + "m.direct_to_device", + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": 3}}}, + "message_id": "3", + }, + ) + ) + + # ... which should arrive + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": 3}, + }, + ) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py new file mode 100644 index 0000000000..283eccd53f --- /dev/null +++ b/tests/rest/client/test_shared_rooms.py @@ -0,0 +1,142 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.rest.client import login, room, shared_rooms + +from tests import unittest +from tests.server import FakeChannel + + +class UserSharedRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserSharedRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + shared_rooms.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def _get_shared_rooms(self, token, other_user) -> FakeChannel: + return self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" + % other_user, + access_token=token, + ) + + def test_shared_room_list_public(self): + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) + + def test_shared_room_list_private(self): + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + self._check_shared_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) + + def test_shared_room_list_mixed(self): + """ + The shared room list between two users should contain both public and private + rooms. + """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_shared_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ): + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 2) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) + + def test_shared_room_list_after_leave(self): + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + # Check user1's view of shared rooms with user2 + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) + + # Check user2's view of shared rooms with user1 + channel = self._get_shared_rooms(u2_token, u1) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py new file mode 100644 index 0000000000..95be369d4b --- /dev/null +++ b/tests/rest/client/test_sync.py @@ -0,0 +1,686 @@ +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +import synapse.rest.admin +from synapse.api.constants import ( + EventContentFields, + EventTypes, + ReadReceiptEventFields, + RelationTypes, +) +from synapse.rest.client import knock, login, read_marker, receipts, room, sync + +from tests import unittest +from tests.federation.transport.test_knocking import ( + KnockingStrippedStateEventHelperMixin, +) +from tests.server import TimedOutException +from tests.unittest import override_config + + +class FilterTestCase(unittest.HomeserverTestCase): + + user_id = "@apple:test" + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_argless(self): + channel = self.make_request("GET", "/sync") + + self.assertEqual(channel.code, 200) + self.assertIn("next_batch", channel.json_body) + + +class SyncFilterTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_filter_labels(self): + """Test that we can filter by a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def _test_sync_filter_labels(self, sync_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + channel = self.make_request( + "GET", "/sync?filter=%s" % sync_filter, access_token=tok + ) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + + +class SyncTypingTests(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + user_id = True + hijack_auth = False + + def test_sync_backwards_typing(self): + """ + If the typing serial goes backwards and the typing handler is then reset + (such as when the master restarts and sets the typing serial to 0), we + do not incorrectly return typing information that had a serial greater + than the now-reset serial. + """ + typing_url = "/rooms/%s/typing/%s?access_token=%s" + sync_url = "/sync?timeout=3000000&access_token=%s&since=%s" + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Start typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Stop typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": false}', + ) + self.assertEquals(200, channel.code) + + # Start typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + # Should return immediately + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Reset typing serial back to 0, as if the master had. + typing = self.hs.get_typing_handler() + typing._latest_room_serial = 0 + + # Since it checks the state token, we need some state to update to + # invalidate the stream token. + self.helper.send(room, body="There!", tok=other_access_token) + + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # This should time out! But it does not, because our stream token is + # ahead, and therefore it's saying the typing (that we've actually + # already seen) is new, since it's got a token above our new, now-reset + # stream token. + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Clear the typing information, so that it doesn't think everything is + # in the future. + typing._reset() + + # Now it SHOULD fail as it never completes! + with self.assertRaises(TimedOutException): + self.make_request("GET", sync_url % (access_token, next_batch)) + + +class SyncKnockTestCase( + unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin +): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + knock.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to create the room to knock on). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll knock on. + self.room_id = self.helper.create_room_as( + self.user_id, + is_public=False, + room_version="7", + tok=self.tok, + ) + + # Register the second user (used to knock on the room). + self.knocker = self.register_user("knocker", "monkey") + self.knocker_tok = self.login("knocker", "monkey") + + # Perform an initial sync for the knocking user. + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + # Set up some room state to test with. + self.expected_room_state = self.send_example_state_events_to_room( + hs, self.room_id, self.user_id + ) + + @override_config({"experimental_features": {"msc2403_enabled": True}}) + def test_knock_room_state(self): + """Tests that /sync returns state from a room after knocking on it.""" + # Knock on a room + channel = self.make_request( + "POST", + "/_matrix/client/r0/knock/%s" % (self.room_id,), + b"{}", + self.knocker_tok, + ) + self.assertEquals(200, channel.code, channel.result) + + # We expect to see the knock event in the stripped room state later + self.expected_room_state[EventTypes.Member] = { + "content": {"membership": "knock", "displayname": "knocker"}, + "state_key": "@knocker:test", + } + + # Check that /sync includes stripped state from the room + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.knocker_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Extract the stripped room state events from /sync + knock_entry = channel.json_body["rooms"]["knock"] + room_state_events = knock_entry[self.room_id]["knock_state"]["events"] + + # Validate that the knock membership event came last + self.assertEqual(room_state_events[-1]["type"], EventTypes.Member) + + # Validate the stripped room state events + self.check_knock_room_state_against_room_state( + room_state_events, self.expected_room_state + ) + + +class ReadReceiptsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + receipts.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Join the second user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_hidden_read_receipts(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt to tell the server the first user's message was read + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that the first user can't see the other user's hidden read receipt + self.assertEqual(self._get_read_receipt(), None) + + def test_read_receipt_with_empty_body(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + def _get_read_receipt(self): + """Syncs and returns the read receipt.""" + + # Checks if event is a read receipt + def is_read_receipt(event): + return event["type"] == "m.receipt" + + # Sync + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + # Return the read receipt + ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ + "ephemeral" + ]["events"] + return next(filter(is_read_receipt, ephemeral_events), None) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that hidden read receipts don't break unread counts + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, + "m.room.name", + {"name": "my super room"}, + tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, + "m.room.topic", + {"topic": "welcome!!!"}, + tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, + "org.matrix.custom_type", + {"body": "hello"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: int): + """Syncs and compares the unread count with the expected value.""" + + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], + expected_count, + room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + +class SyncCacheTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_noop_sync_does_not_tightloop(self): + """If the sync times out, we shouldn't cache the result + + Essentially a regression test for #8518. + """ + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # we should immediately get an initial sync response + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.json_body) + + # now, make an incremental sync request, with a timeout + next_batch = channel.json_body["next_batch"] + channel = self.make_request( + "GET", + f"/sync?since={next_batch}&timeout=10000", + access_token=self.tok, + await_result=False, + ) + # that should block for 10 seconds + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=9900) + channel.await_result(timeout_ms=200) + self.assertEqual(channel.code, 200, channel.json_body) + + # we expect the next_batch in the result to be the same as before + self.assertEqual(channel.json_body["next_batch"], next_batch) + + # another incremental sync should also block. + channel = self.make_request( + "GET", + f"/sync?since={next_batch}&timeout=10000", + access_token=self.tok, + await_result=False, + ) + # that should block for 10 seconds + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=9900) + channel.await_result(timeout_ms=200) + self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py new file mode 100644 index 0000000000..b54b004733 --- /dev/null +++ b/tests/rest/client/test_typing.py @@ -0,0 +1,121 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +from unittest.mock import Mock + +from synapse.rest.client import room +from synapse.types import UserID + +from tests import unittest + +PATH_PREFIX = "/_matrix/client/api/v1" + + +class RoomTypingTestCase(unittest.HomeserverTestCase): + """Tests /rooms/$room_id/typing/$user_id REST API.""" + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + ) + + self.event_source = hs.get_event_sources().sources["typing"] + + hs.get_federation_handler = Mock() + + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + hs.get_auth().get_user_by_access_token = get_user_by_access_token + + async def _insert_client_ip(*args, **kwargs): + return None + + hs.get_datastore().insert_client_ip = _insert_client_ip + + return hs + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + # Need another user to make notifications actually work + self.helper.join(self.room_id, user="@jim:red") + + def test_set_typing(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 1) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.user_id]}, + } + ], + ) + + def test_set_not_typing(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": false}', + ) + self.assertEquals(200, channel.code) + + def test_typing_timeout(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 1) + + self.reactor.advance(36) + + self.assertEquals(self.event_source.get_current_key(), 2) + + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py new file mode 100644 index 0000000000..72f976d8e2 --- /dev/null +++ b/tests/rest/client/test_upgrade_room.py @@ -0,0 +1,159 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.rest import admin +from synapse.rest.client import login, room, room_upgrade_rest_servlet + +from tests import unittest +from tests.server import FakeChannel + + +class UpgradeRoomTest(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + self.creator = self.register_user("creator", "pass") + self.creator_token = self.login(self.creator, "pass") + + self.other = self.register_user("user", "pass") + self.other_token = self.login(self.other, "pass") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) + self.helper.join(self.room_id, self.other, tok=self.other_token) + + def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: + # We never want a cached response. + self.reactor.advance(5 * 60 + 1) + + return self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, + # This will upgrade a room to the same version, but that's fine. + content={"new_version": DEFAULT_ROOM_VERSION}, + access_token=token or self.creator_token, + ) + + def test_upgrade(self): + """ + Upgrading a room should work fine. + """ + channel = self._upgrade_room() + self.assertEquals(200, channel.code, channel.result) + self.assertIn("replacement_room", channel.json_body) + + def test_not_in_room(self): + """ + Upgrading a room should work fine. + """ + # THe user isn't in the room. + roomless = self.register_user("roomless", "pass") + roomless_token = self.login(roomless, "pass") + + channel = self._upgrade_room(roomless_token) + self.assertEquals(403, channel.code, channel.result) + + def test_power_levels(self): + """ + Another user can upgrade the room if their power level is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users"][self.other] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_user_default(self): + """ + Another user can upgrade the room if the default power level for users is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users_default"] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_tombstone(self): + """ + Another user can upgrade the room if they can send the tombstone event. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["events"]["m.room.tombstone"] = 0 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py new file mode 100644 index 0000000000..954ad1a1fd --- /dev/null +++ b/tests/rest/client/utils.py @@ -0,0 +1,654 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +import time +import urllib.parse +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union +from unittest.mock import patch + +import attr + +from twisted.web.resource import Resource +from twisted.web.server import Site + +from synapse.api.constants import Membership +from synapse.types import JsonDict + +from tests.server import FakeChannel, FakeSite, make_request +from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser + + +@attr.s +class RestHelper: + """Contains extra helper functions to quickly and clearly perform a given + REST action, which isn't the focus of the test. + """ + + hs = attr.ib() + site = attr.ib(type=Site) + auth_user_id = attr.ib() + + def create_room_as( + self, + room_creator: Optional[str] = None, + is_public: bool = True, + room_version: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = 200, + extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ) -> str: + """ + Create a room. + + Args: + room_creator: The user ID to create the room with. + is_public: If True, the `visibility` parameter will be set to the + default (public). Otherwise, the `visibility` parameter will be set + to "private". + room_version: The room version to create the room as. Defaults to Synapse's + default room version. + tok: The access token to use in the request. + expect_code: The expected HTTP response code. + + Returns: + The ID of the newly created room. + """ + temp_id = self.auth_user_id + self.auth_user_id = room_creator + path = "/_matrix/client/r0/createRoom" + content = extra_content or {} + if not is_public: + content["visibility"] = "private" + if room_version: + content["room_version"] = room_version + if tok: + path = path + "?access_token=%s" % tok + + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + path, + json.dumps(content).encode("utf8"), + custom_headers=custom_headers, + ) + + assert channel.result["code"] == b"%d" % expect_code, channel.result + self.auth_user_id = temp_id + + if expect_code == 200: + return channel.json_body["room_id"] + + def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=src, + targ=targ, + tok=tok, + membership=Membership.INVITE, + expect_code=expect_code, + ) + + def join(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.JOIN, + expect_code=expect_code, + ) + + def leave(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.LEAVE, + expect_code=expect_code, + ) + + def change_membership( + self, + room: str, + src: str, + targ: str, + membership: str, + extra_data: Optional[dict] = None, + tok: Optional[str] = None, + expect_code: int = 200, + ) -> None: + """ + Send a membership state event into a room. + + Args: + room: The ID of the room to send to + src: The mxid of the event sender + targ: The mxid of the event's target. The state key + membership: The type of membership event + extra_data: Extra information to include in the content of the event + tok: The user access token to use + expect_code: The expected HTTP response code + """ + temp_id = self.auth_user_id + self.auth_user_id = src + + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + if tok: + path = path + "?access_token=%s" % tok + + data = {"membership": membership} + data.update(extra_data or {}) + + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(data).encode("utf8"), + ) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + self.auth_user_id = temp_id + + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): + if body is None: + body = "body_text_here" + + content = {"msgtype": "m.text", "body": body} + + return self.send_event( + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, + ) + + def send_event( + self, + room_id, + type, + content: Optional[dict] = None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, + ) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def _read_write_state( + self, + room_id: str, + event_type: str, + body: Optional[Dict[str, Any]], + tok: str, + expect_code: int = 200, + state_key: str = "", + method: str = "GET", + ) -> Dict: + """Read or write some state from a given room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + If None, the request to the server will have an empty body + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + method: "GET" or "PUT" for reading or writing state, respectively + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) + if tok: + path = path + "?access_token=%s" % tok + + # Set request body if provided + content = b"" + if body is not None: + content = json.dumps(body).encode("utf8") + + channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def get_state( + self, + room_id: str, + event_type: str, + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Gets some state from a room + + Args: + room_id: + event_type: The type of state event + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, None, tok, expect_code, state_key, method="GET" + ) + + def send_state( + self, + room_id: str, + event_type: str, + body: Dict[str, Any], + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Set some state in a room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, body, tok, expect_code, state_key, method="PUT" + ) + + def upload_media( + self, + resource: Resource, + image_data: bytes, + tok: str, + filename: str = "test.png", + expect_code: int = 200, + ) -> dict: + """Upload a piece of test media to the media repo + Args: + resource: The resource that will handle the upload request + image_data: The image data to upload + tok: The user token to use during the upload + filename: The filename of the media to be uploaded + expect_code: The return code to expect from attempting to upload the media + """ + image_length = len(image_data) + path = "/_matrix/media/r0/upload?filename=%s" % (filename,) + channel = make_request( + self.hs.get_reactor(), + FakeSite(resource), + "POST", + path, + content=image_data, + access_token=tok, + custom_headers=[(b"Content-Length", str(image_length))], + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def login_via_oidc(self, remote_user_id: str) -> JsonDict: + """Log in (as a new user) via OIDC + + Returns the result of the final token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + client_redirect_url = "https://x" + channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + + # expect a confirmation page + assert channel.code == 200, channel.result + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + assert channel.code == 200 + return channel.json_body + + def auth_via_oidc( + self, + user_info_dict: JsonDict, + client_redirect_url: Optional[str] = None, + ui_auth_session_id: Optional[str] = None, + ) -> FakeChannel: + """Perform an OIDC authentication flow via a mock OIDC provider. + + This can be used for either login or user-interactive auth. + + Starts by making a request to the relevant synapse redirect endpoint, which is + expected to serve a 302 to the OIDC provider. We then make a request to the + OIDC callback endpoint, intercepting the HTTP requests that will get sent back + to the OIDC provider. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + + Args: + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. + client_redirect_url: for a login flow, the client redirect URL to pass to + the login redirect endpoint + ui_auth_session_id: if set, we will perform a UI Auth flow. The session id + of the UI auth. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + Note that the response code may be a 200, 302 or 400 depending on how things + went. + """ + + cookies = {} + + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + + # we now have a URI for the OIDC IdP, but we skip that and go straight + # back to synapse's OIDC callback resource. However, we do need the "state" + # param that synapse passes to the IdP via query params, as well as the cookie + # that synapse passes to the client. + + oauth_uri_path, _ = oauth_uri.split("?", 1) + assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + "unexpected SSO URI " + oauth_uri_path + ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, + oauth_uri: str, + cookies: Mapping[str, str], + user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) + params = urllib.parse.parse_qs(oauth_uri_qs) + callback_uri = "%s?%s" % ( + urllib.parse.urlparse(params["redirect_uri"][0]).path, + urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + ) + + # before we hit the callback uri, stub out some methods in the http client so + # that we don't have to handle full HTTPS requests. + # (expected url, json response) pairs, in the order we expect them. + expected_requests = [ + # first we get a hit to the token endpoint, which we tell to return + # a dummy OIDC access token + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), + # and then one to the user_info endpoint, which returns our remote user id. + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), + ] + + async def mock_req(method: str, uri: str, data=None, headers=None): + (expected_uri, resp_obj) = expected_requests.pop(0) + assert uri == expected_uri + resp = FakeResponse( + code=200, + phrase=b"OK", + body=json.dumps(resp_obj).encode("utf-8"), + ) + return resp + + with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + # now hit the callback URI with the right params and a made-up code + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + callback_uri, + custom_headers=[ + ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() + ], + ) + return channel + + def initiate_sso_login( + self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the login-via-sso redirect endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the login + servlet to be mounted. + + Args: + client_redirect_url: the client redirect URL to pass to the login redirect + endpoint + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets redirected to (ie, the SSO server) + """ + params = {} + if client_redirect_url: + params["redirectUrl"] = client_redirect_url + + # hit the redirect url (which should redirect back to the redirect url. This + # is the easiest way of figuring out what the Host header ought to be set to + # to keep Synapse happy. + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), + ) + assert channel.code == 302 + + # hit the redirect url again with the right Host header, which should now issue + # a cookie and redirect to the SSO provider. + location = channel.headers.getRawHeaders("Location")[0] + parts = urllib.parse.urlsplit(location) + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + urllib.parse.urlunsplit(("", "") + parts[2:]), + custom_headers=[ + ("Host", parts[1]), + ], + ) + + assert channel.code == 302 + channel.extract_cookies(cookies) + return channel.headers.getRawHeaders("Location")[0] + + def initiate_sso_ui_auth( + self, ui_auth_session_id: str, cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the ui-auth-via-sso endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the + AuthRestServlet to be mounted. + + Args: + ui_auth_session_id: the session id of the UI auth + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets linked to (ie, the SSO server) + """ + sso_redirect_endpoint = ( + "/_matrix/client/r0/auth/m.login.sso/fallback/web?" + + urllib.parse.urlencode({"session": ui_auth_session_id}) + ) + # hit the redirect url (which will issue a cookie and state) + channel = make_request( + self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint + ) + # that should serve a confirmation page + assert channel.code == 200, channel.text_body + channel.extract_cookies(cookies) + + # parse the confirmation page to fish out the link. + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + assert len(p.links) == 1, "not exactly one link in confirmation page" + oauth_uri = p.links[0] + return oauth_uri + + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" +TEST_OIDC_CONFIG = { + "enabled": True, + "discover": False, + "issuer": "https://issuer.test", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py deleted file mode 100644 index 5e83dba2ed..0000000000 --- a/tests/rest/client/v1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py deleted file mode 100644 index d2181ea907..0000000000 --- a/tests/rest/client/v1/test_directory.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.rest import admin -from synapse.rest.client import directory, login, room -from synapse.types import RoomAlias -from synapse.util.stringutils import random_string - -from tests import unittest -from tests.unittest import override_config - - -class DirectoryTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["require_membership_for_aliases"] = True - - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, homeserver): - self.room_owner = self.register_user("room_owner", "test") - self.room_owner_tok = self.login("room_owner", "test") - - self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok - ) - - self.user = self.register_user("user", "test") - self.user_tok = self.login("user", "test") - - def test_state_event_not_in_room(self): - self.ensure_user_left_room() - self.set_alias_via_state_event(403) - - def test_directory_endpoint_not_in_room(self): - self.ensure_user_left_room() - self.set_alias_via_directory(403) - - def test_state_event_in_room_too_long(self): - self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) - - def test_directory_in_room_too_long(self): - self.ensure_user_joined_room() - self.set_alias_via_directory(400, alias_length=256) - - @override_config({"default_room_version": 5}) - def test_state_event_user_in_v5_room(self): - """Test that a regular user can add alias events before room v6""" - self.ensure_user_joined_room() - self.set_alias_via_state_event(200) - - @override_config({"default_room_version": 6}) - def test_state_event_v6_room(self): - """Test that a regular user can *not* add alias events from room v6""" - self.ensure_user_joined_room() - self.set_alias_via_state_event(403) - - def test_directory_in_room(self): - self.ensure_user_joined_room() - self.set_alias_via_directory(200) - - def test_room_creation_too_long(self): - url = "/_matrix/client/r0/createRoom" - - # We use deliberately a localpart under the length threshold so - # that we can make sure that the check is done on the whole alias. - data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} - request_data = json.dumps(data) - channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_room_creation(self): - url = "/_matrix/client/r0/createRoom" - - # Check with an alias of allowed length. There should already be - # a test that ensures it works in test_register.py, but let's be - # as cautious as possible here. - data = {"room_alias_name": random_string(5)} - request_data = json.dumps(data) - channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, 200, channel.result) - - def set_alias_via_state_event(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( - self.room_id, - self.hs.hostname, - ) - - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def set_alias_via_directory(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) - data = {"room_id": self.room_id} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def random_alias(self, length): - return RoomAlias(random_string(length), self.hs.hostname).to_string() - - def ensure_user_left_room(self): - self.ensure_membership("leave") - - def ensure_user_joined_room(self): - self.ensure_membership("join") - - def ensure_membership(self, membership): - try: - if membership == "leave": - self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) - if membership == "join": - self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) - except AssertionError: - # We don't care whether the leave request didn't return a 200 (e.g. - # if the user isn't already in the room), because we only want to - # make sure the user isn't in the room. - pass diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py deleted file mode 100644 index a90294003e..0000000000 --- a/tests/rest/client/v1/test_events.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" Tests REST events for /events paths.""" - -from unittest.mock import Mock - -import synapse.rest.admin -from synapse.rest.client import events, login, room - -from tests import unittest - - -class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): - """Tests event streaming (GET /events).""" - - servlets = [ - events.register_servlets, - room.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - config["enable_registration_captcha"] = False - config["enable_registration"] = True - config["auto_join_rooms"] = [] - - hs = self.setup_test_homeserver(config=config) - - hs.get_federation_handler = Mock() - - return hs - - def prepare(self, reactor, clock, hs): - - # register an account - self.user_id = self.register_user("sid1", "pass") - self.token = self.login(self.user_id, "pass") - - # register a 2nd account - self.other_user = self.register_user("other2", "pass") - self.other_token = self.login(self.other_user, "pass") - - def test_stream_basic_permissions(self): - # invalid token, expect 401 - # note: this is in violation of the original v1 spec, which expected - # 403. However, since the v1 spec no longer exists and the v1 - # implementation is now part of the r0 implementation, the newer - # behaviour is used instead to be consistent with the r0 spec. - # see issue #2602 - channel = self.make_request( - "GET", "/events?access_token=%s" % ("invalid" + self.token,) - ) - self.assertEquals(channel.code, 401, msg=channel.result) - - # valid token, expect content - channel = self.make_request( - "GET", "/events?access_token=%s&timeout=0" % (self.token,) - ) - self.assertEquals(channel.code, 200, msg=channel.result) - self.assertTrue("chunk" in channel.json_body) - self.assertTrue("start" in channel.json_body) - self.assertTrue("end" in channel.json_body) - - def test_stream_room_permissions(self): - room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) - self.helper.send(room_id, tok=self.other_token) - - # invited to room (expect no content for room) - self.helper.invite( - room_id, src=self.other_user, targ=self.user_id, tok=self.other_token - ) - - # valid token, expect content - channel = self.make_request( - "GET", "/events?access_token=%s&timeout=0" % (self.token,) - ) - self.assertEquals(channel.code, 200, msg=channel.result) - - # We may get a presence event for ourselves down - self.assertEquals( - 0, - len( - [ - c - for c in channel.json_body["chunk"] - if not ( - c.get("type") == "m.presence" - and c["content"].get("user_id") == self.user_id - ) - ] - ), - ) - - # joined room (expect all content for room) - self.helper.join(room=room_id, user=self.user_id, tok=self.token) - - # left to room (expect no content for room) - - def TODO_test_stream_items(self): - # new user, no content - - # join room, expect 1 item (join) - - # send message, expect 2 items (join,send) - - # set topic, expect 3 items (join,send,topic) - - # someone else join room, expect 4 (join,send,topic,join) - - # someone else send message, expect 5 (join,send.topic,join,send) - - # someone else set topic, expect 6 (join,send,topic,join,send,topic) - pass - - -class GetEventsTestCase(unittest.HomeserverTestCase): - servlets = [ - events.register_servlets, - room.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def prepare(self, hs, reactor, clock): - - # register an account - self.user_id = self.register_user("sid1", "pass") - self.token = self.login(self.user_id, "pass") - - self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) - - def test_get_event_via_events(self): - resp = self.helper.send(self.room_id, tok=self.token) - event_id = resp["event_id"] - - channel = self.make_request( - "GET", - "/events/" + event_id, - access_token=self.token, - ) - self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py deleted file mode 100644 index eba3552b19..0000000000 --- a/tests/rest/client/v1/test_login.py +++ /dev/null @@ -1,1345 +0,0 @@ -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import urllib.parse -from typing import Any, Dict, List, Optional, Union -from unittest.mock import Mock -from urllib.parse import urlencode - -import pymacaroons - -from twisted.web.resource import Resource - -import synapse.rest.admin -from synapse.appservice import ApplicationService -from synapse.rest.client import devices, login, logout, register -from synapse.rest.client.account import WhoamiRestServlet -from synapse.rest.synapse.client import build_synapse_client_resource_tree -from synapse.types import create_requester - -from tests import unittest -from tests.handlers.test_oidc import HAS_OIDC -from tests.handlers.test_saml import has_saml2 -from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG -from tests.test_utils.html_parsers import TestHtmlParser -from tests.unittest import HomeserverTestCase, override_config, skip_unless - -try: - import jwt - - HAS_JWT = True -except ImportError: - HAS_JWT = False - - -# synapse server name: used to populate public_baseurl in some tests -SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" - -# public_baseurl for some tests. It uses an http:// scheme because -# FakeChannel.isSecure() returns False, so synapse will see the requested uri as -# http://..., so using http in the public_baseurl stops Synapse trying to redirect to -# https://.... -BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) - -# CAS server used in some tests -CAS_SERVER = "https://fake.test" - -# just enough to tell pysaml2 where to redirect to -SAML_SERVER = "https://test.saml.server/idp/sso" -TEST_SAML_METADATA = """ - - - - - -""" % { - "SAML_SERVER": SAML_SERVER, -} - -LOGIN_URL = b"/_matrix/client/r0/login" -TEST_URL = b"/_matrix/client/r0/account/whoami" - -# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + -TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' - -# the query params in TEST_CLIENT_REDIRECT_URL -EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] - -# (possibly experimental) login flows we expect to appear in the list after the normal -# ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] - - -class LoginRestServletTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - logout.register_servlets, - devices.register_servlets, - lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.enable_registration = True - self.hs.config.registrations_require_3pid = [] - self.hs.config.auto_join_rooms = [] - self.hs.config.enable_registration_captcha = False - - return self.hs - - @override_config( - { - "rc_login": { - "address": {"per_second": 0.17, "burst_count": 5}, - # Prevent the account login ratelimiter from raising first - # - # This is normally covered by the default test homeserver config - # which sets these values to 10000, but as we're overriding the entire - # rc_login dict here, we need to set this manually as well - "account": {"per_second": 10000, "burst_count": 10000}, - } - } - ) - def test_POST_ratelimiting_per_address(self): - # Create different users so we're sure not to be bothered by the per-user - # ratelimiter. - for i in range(0, 6): - self.register_user("kermit" + str(i), "monkey") - - for i in range(0, 6): - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower - # than 1min. - self.assertTrue(retry_after_ms < 6000) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config( - { - "rc_login": { - "account": {"per_second": 0.17, "burst_count": 5}, - # Prevent the address login ratelimiter from raising first - # - # This is normally covered by the default test homeserver config - # which sets these values to 10000, but as we're overriding the entire - # rc_login dict here, we need to set this manually as well - "address": {"per_second": 10000, "burst_count": 10000}, - } - } - ) - def test_POST_ratelimiting_per_account(self): - self.register_user("kermit", "monkey") - - for i in range(0, 6): - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower - # than 1min. - self.assertTrue(retry_after_ms < 6000) - - self.reactor.advance(retry_after_ms / 1000.0) - - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config( - { - "rc_login": { - # Prevent the address login ratelimiter from raising first - # - # This is normally covered by the default test homeserver config - # which sets these values to 10000, but as we're overriding the entire - # rc_login dict here, we need to set this manually as well - "address": {"per_second": 10000, "burst_count": 10000}, - "failed_attempts": {"per_second": 0.17, "burst_count": 5}, - } - } - ) - def test_POST_ratelimiting_per_account_failed_attempts(self): - self.register_user("kermit", "monkey") - - for i in range(0, 6): - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "notamonkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"403", channel.result) - - # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower - # than 1min. - self.assertTrue(retry_after_ms < 6000) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "notamonkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): - self.register_user("kermit", "monkey") - - # we shouldn't be able to make requests without an access token - channel = self.make_request(b"GET", TEST_URL) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") - - # log in as normal - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.code, 200, channel.result) - access_token = channel.json_body["access_token"] - device_id = channel.json_body["device_id"] - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # - # test behaviour after deleting the expired device - # - - # we now log in as a different device - access_token_2 = self.login("kermit", "monkey") - - # more requests with the expired token should still return a soft-logout - self.reactor.advance(3600) - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # ... but if we delete that device, it will be a proper logout - self._delete_device(access_token_2, "kermit", "monkey", device_id) - - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], False) - - def _delete_device(self, access_token, user_id, password, device_id): - """Perform the UI-Auth to delete a device""" - channel = self.make_request( - b"DELETE", "devices/" + device_id, access_token=access_token - ) - self.assertEquals(channel.code, 401, channel.result) - # check it's a UI-Auth fail - self.assertEqual( - set(channel.json_body.keys()), - {"flows", "params", "session"}, - channel.result, - ) - - auth = { - "type": "m.login.password", - # https://github.com/matrix-org/synapse/issues/5665 - # "identifier": {"type": "m.id.user", "user": user_id}, - "user": user_id, - "password": password, - "session": channel.json_body["session"], - } - - channel = self.make_request( - b"DELETE", - "devices/" + device_id, - access_token=access_token, - content={"auth": auth}, - ) - self.assertEquals(channel.code, 200, channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): - self.register_user("kermit", "monkey") - - # log in as normal - access_token = self.login("kermit", "monkey") - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # Now try to hard logout this session - channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): - self.register_user("kermit", "monkey") - - # log in as normal - access_token = self.login("kermit", "monkey") - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # Now try to hard log out all of the user's sessions - channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) - - -@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") -class MultiSSOTestCase(unittest.HomeserverTestCase): - """Tests for homeservers with multiple SSO providers enabled""" - - servlets = [ - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - - config["public_baseurl"] = BASE_URL - - config["cas_config"] = { - "enabled": True, - "server_url": CAS_SERVER, - "service_url": "https://matrix.goodserver.com:8448", - } - - config["saml2_config"] = { - "sp_config": { - "metadata": {"inline": [TEST_SAML_METADATA]}, - # use the XMLSecurity backend to avoid relying on xmlsec1 - "crypto_backend": "XMLSecurity", - }, - } - - # default OIDC provider - config["oidc_config"] = TEST_OIDC_CONFIG - - # additional OIDC providers - config["oidc_providers"] = [ - { - "idp_id": "idp1", - "idp_name": "IDP1", - "discover": False, - "issuer": "https://issuer1", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": "https://issuer1/auth", - "token_endpoint": "https://issuer1/token", - "userinfo_endpoint": "https://issuer1/userinfo", - "user_mapping_provider": { - "config": {"localpart_template": "{{ user.sub }}"} - }, - } - ] - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d.update(build_synapse_client_resource_tree(self.hs)) - return d - - def test_get_login_flows(self): - """GET /login should return password and SSO flows""" - channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) - - expected_flow_types = [ - "m.login.cas", - "m.login.sso", - "m.login.token", - "m.login.password", - ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] - - self.assertCountEqual( - [f["type"] for f in channel.json_body["flows"]], expected_flow_types - ) - - @override_config({"experimental_features": {"msc2858_enabled": True}}) - def test_get_msc2858_login_flows(self): - """The SSO flow should include IdP info if MSC2858 is enabled""" - channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) - - # stick the flows results in a dict by type - flow_results: Dict[str, Any] = {} - for f in channel.json_body["flows"]: - flow_type = f["type"] - self.assertNotIn( - flow_type, flow_results, "duplicate flow type %s" % (flow_type,) - ) - flow_results[flow_type] = f - - self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") - sso_flow = flow_results.pop("m.login.sso") - # we should have a set of IdPs - self.assertCountEqual( - sso_flow["org.matrix.msc2858.identity_providers"], - [ - {"id": "cas", "name": "CAS"}, - {"id": "saml", "name": "SAML"}, - {"id": "oidc-idp1", "name": "IDP1"}, - {"id": "oidc", "name": "OIDC"}, - ], - ) - - # the rest of the flows are simple - expected_flows = [ - {"type": "m.login.cas"}, - {"type": "m.login.token"}, - {"type": "m.login.password"}, - ] + ADDITIONAL_LOGIN_FLOWS - - self.assertCountEqual(flow_results.values(), expected_flows) - - def test_multi_sso_redirect(self): - """/login/sso/redirect should redirect to an identity picker""" - # first hit the redirect url, which should redirect to our idp picker - channel = self._make_sso_redirect_request(False, None) - self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] - - # hitting that picker should give us some HTML - channel = self.make_request("GET", uri) - self.assertEqual(channel.code, 200, channel.result) - - # parse the form to check it has fields assumed elsewhere in this class - html = channel.result["body"].decode("utf-8") - p = TestHtmlParser() - p.feed(html) - p.close() - - # there should be a link for each href - returned_idps: List[str] = [] - for link in p.links: - path, query = link.split("?", 1) - self.assertEqual(path, "pick_idp") - params = urllib.parse.parse_qs(query) - self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) - returned_idps.append(params["idp"][0]) - - self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - - def test_multi_sso_redirect_to_cas(self): - """If CAS is chosen, should redirect to the CAS server""" - - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=cas", - shorthand=False, - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - cas_uri = location_headers[0] - cas_uri_path, cas_uri_query = cas_uri.split("?", 1) - - # it should redirect us to the login page of the cas server - self.assertEqual(cas_uri_path, CAS_SERVER + "/login") - - # check that the redirectUrl is correctly encoded in the service param - ie, the - # place that CAS will redirect to - cas_uri_params = urllib.parse.parse_qs(cas_uri_query) - service_uri = cas_uri_params["service"][0] - _, service_uri_query = service_uri.split("?", 1) - service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - - def test_multi_sso_redirect_to_saml(self): - """If SAML is chosen, should redirect to the SAML server""" - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=saml", - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - saml_uri = location_headers[0] - saml_uri_path, saml_uri_query = saml_uri.split("?", 1) - - # it should redirect us to the login page of the SAML server - self.assertEqual(saml_uri_path, SAML_SERVER) - - # the RelayState is used to carry the client redirect url - saml_uri_params = urllib.parse.parse_qs(saml_uri_query) - relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - - def test_login_via_oidc(self): - """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - oidc_uri = location_headers[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - # ... and should have set a cookie including the redirect url - cookie_headers = channel.headers.getRawHeaders("Set-Cookie") - assert cookie_headers - cookies: Dict[str, str] = {} - for h in cookie_headers: - key, value = h.split(";")[0].split("=", maxsplit=1) - cookies[key] = value - - oidc_session_cookie = cookies["oidc_session"] - macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) - self.assertEqual( - self._get_value_from_macaroon(macaroon, "client_redirect_url"), - TEST_CLIENT_REDIRECT_URL, - ) - - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) - - # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) - content_type_headers = channel.headers.getRawHeaders("Content-Type") - assert content_type_headers - self.assertTrue(content_type_headers[-1].startswith("text/html")) - p = TestHtmlParser() - p.feed(channel.text_body) - p.close() - - # ... which should contain our redirect link - self.assertEqual(len(p.links), 1) - path, query = p.links[0].split("?", 1) - self.assertEqual(path, "https://x") - - # it will have url-encoded the params properly, so we'll have to parse them - params = urllib.parse.parse_qsl( - query, keep_blank_values=True, strict_parsing=True, errors="strict" - ) - self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) - self.assertEqual(params[2][0], "loginToken") - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - login_token = params[2][1] - chan = self.make_request( - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@user1:test") - - def test_multi_sso_redirect_to_unknown(self): - """An unknown IdP should cause a 400""" - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_client_idp_redirect_to_unknown(self): - """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request(False, "xxx") - self.assertEqual(channel.code, 404, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - - def test_client_idp_redirect_to_oidc(self): - """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request(False, "oidc") - self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - @override_config({"experimental_features": {"msc2858_enabled": True}}) - def test_client_msc2858_redirect_to_oidc(self): - """Test the unstable API""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - def test_client_idp_redirect_msc2858_disabled(self): - """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - def _make_sso_redirect_request( - self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None - ): - """Send a request to /_matrix/client/r0/login/sso/redirect - - ... or the unstable equivalent - - ... possibly specifying an IDP provider - """ - endpoint = ( - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" - if unstable_endpoint - else "/_matrix/client/r0/login/sso/redirect" - ) - if idp_prov is not None: - endpoint += "/" + idp_prov - endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - - return self.make_request( - "GET", - endpoint, - custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], - ) - - @staticmethod - def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: - prefix = key + " = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(prefix): - return caveat.caveat_id[len(prefix) :] - raise ValueError("No %s caveat in macaroon" % (key,)) - - -class CASTestCase(unittest.HomeserverTestCase): - - servlets = [ - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.base_url = "https://matrix.goodserver.com/" - self.redirect_path = "_synapse/client/login/sso/redirect/confirm" - - config = self.default_config() - config["public_baseurl"] = ( - config.get("public_baseurl") or "https://matrix.goodserver.com:8448" - ) - config["cas_config"] = { - "enabled": True, - "server_url": CAS_SERVER, - } - - cas_user_id = "username" - self.user_id = "@%s:test" % cas_user_id - - async def get_raw(uri, args): - """Return an example response payload from a call to the `/proxyValidate` - endpoint of a CAS server, copied from - https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 - - This needs to be returned by an async function (as opposed to set as the - mock's return value) because the corresponding Synapse code awaits on it. - """ - return ( - """ - - - %s - PGTIOU-84678-8a9d... - - https://proxy2/pgtUrl - https://proxy1/pgtUrl - - - - """ - % cas_user_id - ).encode("utf-8") - - mocked_http_client = Mock(spec=["get_raw"]) - mocked_http_client.get_raw.side_effect = get_raw - - self.hs = self.setup_test_homeserver( - config=config, - proxied_http_client=mocked_http_client, - ) - - return self.hs - - def prepare(self, reactor, clock, hs): - self.deactivate_account_handler = hs.get_deactivate_account_handler() - - def test_cas_redirect_confirm(self): - """Tests that the SSO login flow serves a confirmation page before redirecting a - user to the redirect URL. - """ - base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" - redirect_url = "https://dodgy-site.com/" - - url_parts = list(urllib.parse.urlparse(base_url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({"redirectUrl": redirect_url}) - query.update({"ticket": "ticket"}) - url_parts[4] = urllib.parse.urlencode(query) - cas_ticket_url = urllib.parse.urlunparse(url_parts) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - # Test that the response is HTML. - self.assertEqual(channel.code, 200, channel.result) - content_type_header_value = "" - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type_header_value = header[1].decode("utf8") - - self.assertTrue(content_type_header_value.startswith("text/html")) - - # Test that the body isn't empty. - self.assertTrue(len(channel.result["body"]) > 0) - - # And that it contains our redirect link - self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) - - @override_config( - { - "sso": { - "client_whitelist": [ - "https://legit-site.com/", - "https://other-site.com/", - ] - } - } - ) - def test_cas_redirect_whitelisted(self): - """Tests that the SSO login flow serves a redirect to a whitelisted url""" - self._test_redirect("https://legit-site.com/") - - @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): - self._test_redirect("https://example.com/_matrix/static/client/login") - - def _test_redirect(self, redirect_url): - """Tests that the SSO login flow serves a redirect for the given redirect URL.""" - cas_ticket_url = ( - "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" - % (urllib.parse.quote(redirect_url)) - ) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - self.assertEqual(channel.code, 302) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) - - @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): - """Logging in as a deactivated account should error.""" - redirect_url = "https://legit-site.com/" - - # First login (to create the user). - self._test_redirect(redirect_url) - - # Deactivate the account. - self.get_success( - self.deactivate_account_handler.deactivate_account( - self.user_id, False, create_requester(self.user_id) - ) - ) - - # Request the CAS ticket. - cas_ticket_url = ( - "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" - % (urllib.parse.quote(redirect_url)) - ) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - # Because the user is deactivated they are served an error template. - self.assertEqual(channel.code, 403) - self.assertIn(b"SSO account deactivated", channel.result["body"]) - - -@skip_unless(HAS_JWT, "requires jwt") -class JWTTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - jwt_secret = "secret" - jwt_algorithm = "HS256" - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_secret - self.hs.config.jwt_algorithm = self.jwt_algorithm - return self.hs - - def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) - if isinstance(result, bytes): - return result.decode("ascii") - return result - - def jwt_login(self, *args): - params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) - return channel - - def test_login_jwt_valid_registered(self): - self.register_user("kermit", "monkey") - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - def test_login_jwt_valid_unregistered(self): - channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@frog:test") - - def test_login_jwt_invalid_signature(self): - channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: Signature verification failed", - ) - - def test_login_jwt_expired(self): - channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Signature has expired" - ) - - def test_login_jwt_not_before(self): - now = int(time.time()) - channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: The token is not yet valid (nbf)", - ) - - def test_login_no_sub(self): - channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual(channel.json_body["error"], "Invalid JWT") - - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "issuer": "test-issuer", - } - } - ) - def test_login_iss(self): - """Test validating the issuer claim.""" - # A valid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - # An invalid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid issuer" - ) - - # Not providing an issuer. - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - 'JWT validation failed: Token is missing the "iss" claim', - ) - - def test_login_iss_no_config(self): - """Test providing an issuer claim without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "audiences": ["test-audience"], - } - } - ) - def test_login_aud(self): - """Test validating the audience claim.""" - # A valid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - # An invalid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" - ) - - # Not providing an audience. - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - 'JWT validation failed: Token is missing the "aud" claim', - ) - - def test_login_aud_no_config(self): - """Test providing an audience without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" - ) - - def test_login_no_token(self): - params = {"type": "org.matrix.login.jwt"} - channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") - - -# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use -# RSS256, with a public key configured in synapse as "jwt_secret", and tokens -# signed by the private key. -@skip_unless(HAS_JWT, "requires jwt") -class JWTPubKeyTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - ] - - # This key's pubkey is used as the jwt_secret setting of synapse. Valid - # tokens are signed by this and validated using the pubkey. It is generated - # with `openssl genrsa 512` (not a secure way to generate real keys, but - # good enough for tests!) - jwt_privatekey = "\n".join( - [ - "-----BEGIN RSA PRIVATE KEY-----", - "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", - "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", - "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", - "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", - "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", - "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", - "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", - "-----END RSA PRIVATE KEY-----", - ] - ) - - # Generated with `openssl rsa -in foo.key -pubout`, with the the above - # private key placed in foo.key (jwt_privatekey). - jwt_pubkey = "\n".join( - [ - "-----BEGIN PUBLIC KEY-----", - "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", - "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", - "-----END PUBLIC KEY-----", - ] - ) - - # This key is used to sign tokens that shouldn't be accepted by synapse. - # Generated just like jwt_privatekey. - bad_privatekey = "\n".join( - [ - "-----BEGIN RSA PRIVATE KEY-----", - "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", - "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", - "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", - "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", - "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", - "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", - "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", - "-----END RSA PRIVATE KEY-----", - ] - ) - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_pubkey - self.hs.config.jwt_algorithm = "RS256" - return self.hs - - def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") - if isinstance(result, bytes): - return result.decode("ascii") - return result - - def jwt_login(self, *args): - params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) - return channel - - def test_login_jwt_valid(self): - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - def test_login_jwt_invalid_signature(self): - channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: Signature verification failed", - ) - - -AS_USER = "as_user_alice" - - -class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - register.register_servlets, - ] - - def register_as_user(self, username): - self.make_request( - b"POST", - "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), - {"username": username}, - ) - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - - self.service = ApplicationService( - id="unique_identifier", - token="some_token", - hostname="example.com", - sender="@asbot:example.com", - namespaces={ - ApplicationService.NS_USERS: [ - {"regex": r"@as_user.*", "exclusive": False} - ], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, - ) - self.another_service = ApplicationService( - id="another__identifier", - token="another_token", - hostname="example.com", - sender="@as2bot:example.com", - namespaces={ - ApplicationService.NS_USERS: [ - {"regex": r"@as2_user.*", "exclusive": False} - ], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, - ) - - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) - return self.hs - - def test_login_appservice_user(self): - """Test that an appservice user can use /login""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_login_appservice_user_bot(self): - """Test that the appservice bot can use /login""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": self.service.sender}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_login_appservice_wrong_user(self): - """Test that non-as users cannot login with the as token""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - def test_login_appservice_wrong_as(self): - """Test that as users cannot login with wrong as token""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.another_service.token - ) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - def test_login_appservice_no_token(self): - """Test that users must provide a token when using the appservice - login method - """ - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"401", channel.result) - - -@skip_unless(HAS_OIDC, "requires OIDC") -class UsernamePickerTestCase(HomeserverTestCase): - """Tests for the username picker flow of SSO login""" - - servlets = [login.register_servlets] - - def default_config(self): - config = super().default_config() - config["public_baseurl"] = BASE_URL - - config["oidc_config"] = {} - config["oidc_config"].update(TEST_OIDC_CONFIG) - config["oidc_config"]["user_mapping_provider"] = { - "config": {"display_name_template": "{{ user.displayname }}"} - } - - # whitelist this client URI so we redirect straight to it rather than - # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://x"]} - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d.update(build_synapse_client_resource_tree(self.hs)) - return d - - def test_username_picker(self): - """Test the happy path of a username picker flow.""" - - # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL - ) - - # that should redirect to the username picker - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - picker_url = location_headers[0] - self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") - - # ... with a username_mapping_session cookie - cookies: Dict[str, str] = {} - channel.extract_cookies(cookies) - self.assertIn("username_mapping_session", cookies) - session_id = cookies["username_mapping_session"] - - # introspect the sso handler a bit to check that the username mapping session - # looks ok. - username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions - self.assertIn( - session_id, - username_mapping_sessions, - "session id not found in map", - ) - session = username_mapping_sessions[session_id] - self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) - - # the expiry time should be about 15 minutes away - expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) - self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) - - # Now, submit a username to the username picker, which should serve a redirect - # to the completion page - content = urlencode({b"username": b"bobby"}).encode("utf8") - chan = self.make_request( - "POST", - path=picker_url, - content=content, - content_is_form=True, - custom_headers=[ - ("Cookie", "username_mapping_session=" + session_id), - # old versions of twisted don't do form-parsing without a valid - # content-length header. - ("Content-Length", str(len(content))), - ], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # send a request to the completion page, which should 302 to the client redirectUrl - chan = self.make_request( - "GET", - path=location_headers[0], - custom_headers=[("Cookie", "username_mapping_session=" + session_id)], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # ensure that the returned location matches the requested redirect URL - path, query = location_headers[0].split("?", 1) - self.assertEqual(path, "https://x") - - # it will have url-encoded the params properly, so we'll have to parse them - params = urllib.parse.parse_qsl( - query, keep_blank_values=True, strict_parsing=True, errors="strict" - ) - self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) - self.assertEqual(params[2][0], "loginToken") - - # fish the login token out of the returned redirect uri - login_token = params[2][1] - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - chan = self.make_request( - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py deleted file mode 100644 index 1d152352d1..0000000000 --- a/tests/rest/client/v1/test_presence.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import Mock - -from twisted.internet import defer - -from synapse.handlers.presence import PresenceHandler -from synapse.rest.client import presence -from synapse.types import UserID - -from tests import unittest - - -class PresenceTestCase(unittest.HomeserverTestCase): - """Tests presence REST API.""" - - user_id = "@sid:red" - - user = UserID.from_string(user_id) - servlets = [presence.register_servlets] - - def make_homeserver(self, reactor, clock): - - presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = defer.succeed(None) - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - presence_handler=presence_handler, - ) - - return hs - - def test_put_presence(self): - """ - PUT to the status endpoint with use_presence enabled will call - set_state on the presence handler. - """ - self.hs.config.use_presence = True - - body = {"presence": "here", "status_msg": "beep boop"} - channel = self.make_request( - "PUT", "/presence/%s/status" % (self.user_id,), body - ) - - self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) - - @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): - """ - PUT to the status endpoint with use_presence disabled will NOT call - set_state on the presence handler. - """ - - body = {"presence": "here", "status_msg": "beep boop"} - channel = self.make_request( - "PUT", "/presence/%s/status" % (self.user_id,), body - ) - - self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py deleted file mode 100644 index 2860579c2e..0000000000 --- a/tests/rest/client/v1/test_profile.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /profile paths.""" -from synapse.rest import admin -from synapse.rest.client import login, profile, room - -from tests import unittest - - -class ProfileTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - return self.hs - - def prepare(self, reactor, clock, hs): - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") - self.other = self.register_user("other", "pass", displayname="Bob") - - def test_get_displayname(self): - res = self._get_displayname() - self.assertEqual(res, "owner") - - def test_set_displayname(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res = self._get_displayname() - self.assertEqual(res, "test") - - def test_set_displayname_noauth(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test"}, - ) - self.assertEqual(channel.code, 401, channel.result) - - def test_set_displayname_too_long(self): - """Attempts to set a stupid displayname should get a 400""" - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test" * 100}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - res = self._get_displayname() - self.assertEqual(res, "owner") - - def test_get_displayname_other(self): - res = self._get_displayname(self.other) - self.assertEquals(res, "Bob") - - def test_set_displayname_other(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.other,), - content={"displayname": "test"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_get_avatar_url(self): - res = self._get_avatar_url() - self.assertIsNone(res) - - def test_set_avatar_url(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res = self._get_avatar_url() - self.assertEqual(res, "http://my.server/pic.gif") - - def test_set_avatar_url_noauth(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif"}, - ) - self.assertEqual(channel.code, 401, channel.result) - - def test_set_avatar_url_too_long(self): - """Attempts to set a stupid avatar_url should get a 400""" - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif" * 100}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - res = self._get_avatar_url() - self.assertIsNone(res) - - def test_get_avatar_url_other(self): - res = self._get_avatar_url(self.other) - self.assertIsNone(res) - - def test_set_avatar_url_other(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.other,), - content={"avatar_url": "http://my.server/pic.gif"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - def _get_displayname(self, name=None): - channel = self.make_request( - "GET", "/profile/%s/displayname" % (name or self.owner,) - ) - self.assertEqual(channel.code, 200, channel.result) - return channel.json_body["displayname"] - - def _get_avatar_url(self, name=None): - channel = self.make_request( - "GET", "/profile/%s/avatar_url" % (name or self.owner,) - ) - self.assertEqual(channel.code, 200, channel.result) - return channel.json_body.get("avatar_url") - - -class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - config["require_auth_for_profile_requests"] = True - config["limit_profile_requests_to_users_who_share_rooms"] = True - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, hs): - # User owning the requested profile. - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") - self.profile_url = "/profile/%s" % (self.owner) - - # User requesting the profile. - self.requester = self.register_user("requester", "pass") - self.requester_tok = self.login("requester", "pass") - - self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) - - def test_no_auth(self): - self.try_fetch_profile(401) - - def test_not_in_shared_room(self): - self.ensure_requester_left_room() - - self.try_fetch_profile(403, access_token=self.requester_tok) - - def test_in_shared_room(self): - self.ensure_requester_left_room() - - self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) - - self.try_fetch_profile(200, self.requester_tok) - - def try_fetch_profile(self, expected_code, access_token=None): - self.request_profile(expected_code, access_token=access_token) - - self.request_profile( - expected_code, url_suffix="/displayname", access_token=access_token - ) - - self.request_profile( - expected_code, url_suffix="/avatar_url", access_token=access_token - ) - - def request_profile(self, expected_code, url_suffix="", access_token=None): - channel = self.make_request( - "GET", self.profile_url + url_suffix, access_token=access_token - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def ensure_requester_left_room(self): - try: - self.helper.leave( - room=self.room_id, user=self.requester, tok=self.requester_tok - ) - except AssertionError: - # We don't care whether the leave request didn't return a 200 (e.g. - # if the user isn't already in the room), because we only want to - # make sure the user isn't in the room. - pass - - -class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["require_auth_for_profile_requests"] = True - config["limit_profile_requests_to_users_who_share_rooms"] = True - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, hs): - # User requesting the profile. - self.requester = self.register_user("requester", "pass") - self.requester_tok = self.login("requester", "pass") - - def test_can_lookup_own_profile(self): - """Tests that a user can lookup their own profile without having to be in a room - if 'require_auth_for_profile_requests' is set to true in the server's config. - """ - channel = self.make_request( - "GET", "/profile/" + self.requester, access_token=self.requester_tok - ) - self.assertEqual(channel.code, 200, channel.result) - - channel = self.make_request( - "GET", - "/profile/" + self.requester + "/displayname", - access_token=self.requester_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - channel = self.make_request( - "GET", - "/profile/" + self.requester + "/avatar_url", - access_token=self.requester_tok, - ) - self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py deleted file mode 100644 index d0ce91ccd9..0000000000 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import synapse -from synapse.api.errors import Codes -from synapse.rest.client import login, push_rule, room - -from tests.unittest import HomeserverTestCase - - -class PushRuleAttributesTestCase(HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - push_rule.register_servlets, - ] - hijack_auth = False - - def test_enabled_on_creation(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is enabled upon creation, even though a rule with that - ruleId existed previously and was disabled. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_on_recreation(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is enabled upon creation, even if a rule with that - ruleId existed previously and was disabled. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # disable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": False}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule disabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], False) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_disable(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is disabled and enabled when we ask for it. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # disable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": False}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule disabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], False) - - # re-enable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule enabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_404_when_get_non_existent(self): - """ - Tests that `enabled` gives 404 when the rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # check 404 for deleted rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_get_non_existent_server_rule(self): - """ - Tests that `enabled` gives 404 when the server-default rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_put_non_existent_rule(self): - """ - Tests that `enabled` gives 404 when we put to a rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_put_non_existent_server_rule(self): - """ - Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/.m.muahahah/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_get(self): - """ - Tests that `actions` gives you what you expect on a fresh rule. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET actions for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/actions", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] - ) - - def test_actions_put(self): - """ - Tests that PUT on actions updates the value you'd get from GET. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # change the rule actions - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # GET actions for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/actions", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["actions"], ["dont_notify"]) - - def test_actions_404_when_get_non_existent(self): - """ - Tests that `actions` gives 404 when the rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # check 404 for deleted rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_get_non_existent_server_rule(self): - """ - Tests that `actions` gives 404 when the server-default rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_put_non_existent_rule(self): - """ - Tests that `actions` gives 404 when putting to a rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_put_non_existent_server_rule(self): - """ - Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/.m.muahahah/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py deleted file mode 100644 index 0c9cbb9aff..0000000000 --- a/tests/rest/client/v1/test_rooms.py +++ /dev/null @@ -1,2150 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /rooms paths.""" - -import json -from typing import Iterable -from unittest.mock import Mock, call -from urllib import parse as urlparse - -from twisted.internet import defer - -import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, Membership -from synapse.api.errors import HttpResponseException -from synapse.handlers.pagination import PurgeStatus -from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, room -from synapse.types import JsonDict, RoomAlias, UserID, create_requester -from synapse.util.stringutils import random_string - -from tests import unittest -from tests.test_utils import make_awaitable - -PATH_PREFIX = b"/_matrix/client/api/v1" - - -class RoomBase(unittest.HomeserverTestCase): - rmcreator_id = None - - servlets = [room.register_servlets, room.register_deprecated_servlets] - - def make_homeserver(self, reactor, clock): - - self.hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - - self.hs.get_federation_handler = Mock() - self.hs.get_federation_handler.return_value.maybe_backfill = Mock( - return_value=make_awaitable(None) - ) - - async def _insert_client_ip(*args, **kwargs): - return None - - self.hs.get_datastore().insert_client_ip = _insert_client_ip - - return self.hs - - -class RoomPermissionsTestCase(RoomBase): - """Tests room permissions.""" - - user_id = "@sid1:red" - rmcreator_id = "@notme:red" - - def prepare(self, reactor, clock, hs): - - self.helper.auth_user_id = self.rmcreator_id - # create some rooms under the name rmcreator_id - self.uncreated_rmid = "!aa:test" - self.created_rmid = self.helper.create_room_as( - self.rmcreator_id, is_public=False - ) - self.created_public_rmid = self.helper.create_room_as( - self.rmcreator_id, is_public=True - ) - - # send a message in one of the rooms - self.created_rmid_msg_path = ( - "rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ).encode("ascii") - channel = self.make_request( - "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' - ) - self.assertEquals(200, channel.code, channel.result) - - # set topic for public room - channel = self.make_request( - "PUT", - ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), - b'{"topic":"Public Room Topic"}', - ) - self.assertEquals(200, channel.code, channel.result) - - # auth as user_id now - self.helper.auth_user_id = self.user_id - - def test_can_do_action(self): - msg_content = b'{"msgtype":"m.text","body":"hello"}' - - seq = iter(range(100)) - - def send_msg_path(): - return "/rooms/%s/send/m.room.message/mid%s" % ( - self.created_rmid, - str(next(seq)), - ) - - # send message in uncreated room, expect 403 - channel = self.make_request( - "PUT", - "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), - msg_content, - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room not joined (no state), expect 403 - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room and invited, expect 403 - self.helper.invite( - room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id - ) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room and joined, expect 200 - self.helper.join(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # send message in created room and left, expect 403 - self.helper.leave(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_topic_perms(self): - topic_content = b'{"topic":"My Topic Name"}' - topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid - - # set/get topic in uncreated room, expect 403 - channel = self.make_request( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request( - "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set/get topic in created PRIVATE room not joined, expect 403 - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set topic in created PRIVATE room and invited, expect 403 - self.helper.invite( - room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id - ) - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # get topic in created PRIVATE room and invited, expect 403 - channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set/get topic in created PRIVATE room and joined, expect 200 - self.helper.join(room=self.created_rmid, user=self.user_id) - - # Only room ops can set topic by default - self.helper.auth_user_id = self.rmcreator_id - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.helper.auth_user_id = self.user_id - - channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) - - # set/get topic in created PRIVATE room and left, expect 403 - self.helper.leave(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # get topic in PUBLIC room, not joined, expect 403 - channel = self.make_request( - "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set topic in PUBLIC room, not joined, expect 403 - channel = self.make_request( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content, - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def _test_get_membership( - self, room=None, members: Iterable = frozenset(), expect_code=None - ): - for member in members: - path = "/rooms/%s/state/m.room.member/%s" % (room, member) - channel = self.make_request("GET", path) - self.assertEquals(expect_code, channel.code) - - def test_membership_basic_room_perms(self): - # === room does not exist === - room = self.uncreated_rmid - # get membership of self, get membership of other, uncreated room - # expect all 403s - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # trying to invite people to this room should 403 - self.helper.invite( - room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 - ) - - # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 404s because room doesn't exist on any server - for usr in [self.user_id, self.rmcreator_id]: - self.helper.join(room=room, user=usr, expect_code=404) - self.helper.leave(room=room, user=usr, expect_code=404) - - def test_membership_private_room_perms(self): - room = self.created_rmid - # get membership of self, get membership of other, private room + invite - # expect all 403s - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # get membership of self, get membership of other, private room + joined - # expect all 200s - self.helper.join(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - # get membership of self, get membership of other, private room + left - # expect all 200s - self.helper.leave(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - def test_membership_public_room_perms(self): - room = self.created_public_rmid - # get membership of self, get membership of other, public room + invite - # expect 403 - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # get membership of self, get membership of other, public room + joined - # expect all 200s - self.helper.join(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - # get membership of self, get membership of other, public room + left - # expect 200. - self.helper.leave(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - def test_invited_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - - # set [invite/join/left] of other user, expect 403s - self.helper.invite( - room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 - ) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.JOIN, - expect_code=403, - ) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403, - ) - - def test_joined_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self.helper.join(room=room, user=self.user_id) - - # set invited of self, expect 403 - self.helper.invite( - room=room, src=self.user_id, targ=self.user_id, expect_code=403 - ) - - # set joined of self, expect 200 (NOOP) - self.helper.join(room=room, user=self.user_id) - - other = "@burgundy:red" - # set invited of other, expect 200 - self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) - - # set joined of other, expect 403 - self.helper.change_membership( - room=room, - src=self.user_id, - targ=other, - membership=Membership.JOIN, - expect_code=403, - ) - - # set left of other, expect 403 - self.helper.change_membership( - room=room, - src=self.user_id, - targ=other, - membership=Membership.LEAVE, - expect_code=403, - ) - - # set left of self, expect 200 - self.helper.leave(room=room, user=self.user_id) - - def test_leave_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self.helper.join(room=room, user=self.user_id) - self.helper.leave(room=room, user=self.user_id) - - # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 403s - for usr in [self.user_id, self.rmcreator_id]: - self.helper.change_membership( - room=room, - src=self.user_id, - targ=usr, - membership=Membership.INVITE, - expect_code=403, - ) - - self.helper.change_membership( - room=room, - src=self.user_id, - targ=usr, - membership=Membership.JOIN, - expect_code=403, - ) - - # It is always valid to LEAVE if you've already left (currently.) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403, - ) - - -class RoomsMemberListTestCase(RoomBase): - """Tests /rooms/$room_id/members/list REST events.""" - - user_id = "@sid1:red" - - def test_get_member_list(self): - room_id = self.helper.create_room_as(self.user_id) - channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - def test_get_member_list_no_room(self): - channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_get_member_list_no_permission(self): - room_id = self.helper.create_room_as("@some_other_guy:red") - channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_get_member_list_mixed_memberships(self): - room_creator = "@some_other_guy:red" - room_id = self.helper.create_room_as(room_creator) - room_path = "/rooms/%s/members" % room_id - self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) - # can't see list if you're just invited. - channel = self.make_request("GET", room_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - self.helper.join(room=room_id, user=self.user_id) - # can see list now joined - channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - self.helper.leave(room=room_id, user=self.user_id) - # can see old list once left - channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - -class RoomsCreateTestCase(RoomBase): - """Tests /rooms and /rooms/$room_id REST events.""" - - user_id = "@sid1:red" - - def test_post_room_no_keys(self): - # POST with no config keys, expect new room id - channel = self.make_request("POST", "/createRoom", "{}") - - self.assertEquals(200, channel.code, channel.result) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_visibility_key(self): - # POST with visibility config key, expect new room id - channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_custom_key(self): - # POST with custom config keys, expect new room id - channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_known_and_unknown_keys(self): - # POST with custom + known config keys, expect new room id - channel = self.make_request( - "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' - ) - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_invalid_content(self): - # POST with invalid content / paths, expect 400 - channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEquals(400, channel.code) - - channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEquals(400, channel.code) - - def test_post_room_invitees_invalid_mxid(self): - # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 - # Note the trailing space in the MXID here! - channel = self.make_request( - "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' - ) - self.assertEquals(400, channel.code) - - @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) - def test_post_room_invitees_ratelimit(self): - """Test that invites sent when creating a room are ratelimited by a RateLimiter, - which ratelimits them correctly, including by not limiting when the requester is - exempt from ratelimiting. - """ - - # Build the request's content. We use local MXIDs because invites over federation - # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") - - # Test that the invites are correctly ratelimited. - channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) - self.assertEqual( - "Cannot invite so many users at once", - channel.json_body["error"], - ) - - # Add the current user to the ratelimit overrides, allowing them no ratelimiting. - self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) - ) - - # Test that the invites aren't ratelimited anymore. - channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) - - -class RoomTopicTestCase(RoomBase): - """Tests /rooms/$room_id/topic REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - # create the room - self.room_id = self.helper.create_room_as(self.user_id) - self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - - def test_invalid_puts(self): - # missing keys or invalid json - channel = self.make_request("PUT", self.path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request( - "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' - ) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # valid key, wrong type - content = '{"topic":["Topic name"]}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_topic(self): - # nothing should be there - channel = self.make_request("GET", self.path) - self.assertEquals(404, channel.code, msg=channel.result["body"]) - - # valid put - content = '{"topic":"Topic name"}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # valid get - channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(content), channel.json_body) - - def test_rooms_topic_with_extra_keys(self): - # valid put with extra keys - content = '{"topic":"Seasons","subtopic":"Summer"}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # valid get - channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(content), channel.json_body) - - -class RoomMemberStateTestCase(RoomBase): - """Tests /rooms/$room_id/members/$user_id/state REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_invalid_puts(self): - path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) - # missing keys or invalid json - channel = self.make_request("PUT", path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # valid keys, wrong types - content = '{"membership":["%s","%s","%s"]}' % ( - Membership.INVITE, - Membership.JOIN, - Membership.LEAVE, - ) - channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_members_self(self): - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.user_id, - ) - - # valid join message (NOOP since we made the room) - content = '{"membership":"%s"}' % Membership.JOIN - channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - expected_response = {"membership": Membership.JOIN} - self.assertEquals(expected_response, channel.json_body) - - def test_rooms_members_other(self): - self.other_id = "@zzsid1:red" - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.other_id, - ) - - # valid invite message - content = '{"membership":"%s"}' % Membership.INVITE - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) - - def test_rooms_members_other_custom_keys(self): - self.other_id = "@zzsid1:red" - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.other_id, - ) - - # valid invite message with custom key - content = '{"membership":"%s","invite_text":"%s"}' % ( - Membership.INVITE, - "Join us!", - ) - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) - - -class RoomInviteRatelimitTestCase(RoomBase): - user_id = "@sid1:red" - - servlets = [ - admin.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - @unittest.override_config( - {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_invites_by_rooms_ratelimit(self): - """Tests that invites in a room are actually rate-limited.""" - room_id = self.helper.create_room_as(self.user_id) - - for i in range(3): - self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) - - self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) - - @unittest.override_config( - {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_invites_by_users_ratelimit(self): - """Tests that invites to a specific user are actually rate-limited.""" - - for _ in range(3): - room_id = self.helper.create_room_as(self.user_id) - self.helper.invite(room_id, self.user_id, "@other-users:red") - - room_id = self.helper.create_room_as(self.user_id) - self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) - - -class RoomJoinRatelimitTestCase(RoomBase): - user_id = "@sid1:red" - - servlets = [ - admin.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit(self): - """Tests that local joins are actually rate-limited.""" - for _ in range(3): - self.helper.create_room_as(self.user_id) - - self.helper.create_room_as(self.user_id, expect_code=429) - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit_profile_change(self): - """Tests that sending a profile update into all of the user's joined rooms isn't - rate-limited by the rate-limiter on joins.""" - - # Create and join as many rooms as the rate-limiting config allows in a second. - room_ids = [ - self.helper.create_room_as(self.user_id), - self.helper.create_room_as(self.user_id), - self.helper.create_room_as(self.user_id), - ] - # Let some time for the rate-limiter to forget about our multi-join. - self.reactor.advance(2) - # Add one to make sure we're joined to more rooms than the config allows us to - # join in a second. - room_ids.append(self.helper.create_room_as(self.user_id)) - - # Create a profile for the user, since it hasn't been done on registration. - store = self.hs.get_datastore() - self.get_success( - store.create_profile(UserID.from_string(self.user_id).localpart) - ) - - # Update the display name for the user. - path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id - channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEquals(channel.code, 200, channel.json_body) - - # Check that all the rooms have been sent a profile update into. - for room_id in room_ids: - path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( - room_id, - self.user_id, - ) - - channel = self.make_request("GET", path) - self.assertEquals(channel.code, 200) - - self.assertIn("displayname", channel.json_body) - self.assertEquals(channel.json_body["displayname"], "John Doe") - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit_idempotent(self): - """Tests that the room join endpoints remain idempotent despite rate-limiting - on room joins.""" - room_id = self.helper.create_room_as(self.user_id) - - # Let's test both paths to be sure. - paths_to_test = [ - "/_matrix/client/r0/rooms/%s/join", - "/_matrix/client/r0/join/%s", - ] - - for path in paths_to_test: - # Make sure we send more requests than the rate-limiting config would allow - # if all of these requests ended up joining the user to a room. - for _ in range(4): - channel = self.make_request("POST", path % room_id, {}) - self.assertEquals(channel.code, 200) - - @unittest.override_config( - { - "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, - "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], - "autocreate_auto_join_rooms": True, - }, - ) - def test_autojoin_rooms(self): - user_id = self.register_user("testuser", "password") - - # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) - self.assertEqual(len(rooms), 4) - - -class RoomMessagesTestCase(RoomBase): - """Tests /rooms/$room_id/messages/$user_id/$msg_id REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_invalid_puts(self): - path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - # missing keys or invalid json - channel = self.make_request("PUT", path, b"{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b"text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b"") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_messages_sent(self): - path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - - content = b'{"body":"test","msgtype":{"type":"a"}}' - channel = self.make_request("PUT", path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # custom message types - content = b'{"body":"test","msgtype":"test.custom.text"}' - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # m.text message type - path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) - content = b'{"body":"test2","msgtype":"m.text"}' - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - -class RoomInitialSyncTestCase(RoomBase): - """Tests /rooms/$room_id/initialSync.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - # create the room - self.room_id = self.helper.create_room_as(self.user_id) - - def test_initial_sync(self): - channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEquals(200, channel.code) - - self.assertEquals(self.room_id, channel.json_body["room_id"]) - self.assertEquals("join", channel.json_body["membership"]) - - # Room state is easier to assert on if we unpack it into a dict - state = {} - for event in channel.json_body["state"]: - if "state_key" not in event: - continue - t = event["type"] - if t not in state: - state[t] = [] - state[t].append(event) - - self.assertTrue("m.room.create" in state) - - self.assertTrue("messages" in channel.json_body) - self.assertTrue("chunk" in channel.json_body["messages"]) - self.assertTrue("end" in channel.json_body["messages"]) - - self.assertTrue("presence" in channel.json_body) - - presence_by_user = { - e["content"]["user_id"]: e for e in channel.json_body["presence"] - } - self.assertTrue(self.user_id in presence_by_user) - self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) - - -class RoomMessageListTestCase(RoomBase): - """Tests /rooms/$room_id/messages REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_topo_token_is_accepted(self): - token = "t1-0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) - ) - self.assertEquals(200, channel.code) - self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) - self.assertTrue("chunk" in channel.json_body) - self.assertTrue("end" in channel.json_body) - - def test_stream_token_is_accepted_for_fwd_pagianation(self): - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) - ) - self.assertEquals(200, channel.code) - self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) - self.assertTrue("chunk" in channel.json_body) - self.assertTrue("end" in channel.json_body) - - def test_room_messages_purge(self): - store = self.hs.get_datastore() - pagination_handler = self.hs.get_pagination_handler() - - # Send a first message in the room, which will be removed by the purge. - first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] - first_token = self.get_success( - store.get_topological_token_for_event(first_event_id) - ) - first_token_str = self.get_success(first_token.to_string(store)) - - # Send a second message in the room, which won't be removed, and which we'll - # use as the marker to purge events before. - second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] - second_token = self.get_success( - store.get_topological_token_for_event(second_event_id) - ) - second_token_str = self.get_success(second_token.to_string(store)) - - # Send a third event in the room to ensure we don't fall under any edge case - # due to our marker being the latest forward extremity in the room. - self.helper.send(self.room_id, "message 3") - - # Check that we get the first and second message when querying /messages. - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - second_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) - - # Purge every event before the second event. - purge_id = random_string(16) - pagination_handler._purges_by_id[purge_id] = PurgeStatus() - self.get_success( - pagination_handler._purge_history( - purge_id=purge_id, - room_id=self.room_id, - token=second_token_str, - delete_local_events=True, - ) - ) - - # Check that we only get the second message through /message now that the first - # has been purged. - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - second_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) - - # Check that we get no event, but also no error, when querying /messages with - # the token that was pointing at the first event, because we don't have it - # anymore. - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - first_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) - - -class RoomSearchTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - user_id = True - hijack_auth = False - - def prepare(self, reactor, clock, hs): - - # Register the user who does the searching - self.user_id = self.register_user("user", "pass") - self.access_token = self.login("user", "pass") - - # Register the user who sends the message - self.other_user_id = self.register_user("otheruser", "pass") - self.other_access_token = self.login("otheruser", "pass") - - # Create a room - self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) - - # Invite the other person - self.helper.invite( - room=self.room, - src=self.user_id, - tok=self.access_token, - targ=self.other_user_id, - ) - - # The other user joins - self.helper.join( - room=self.room, user=self.other_user_id, tok=self.other_access_token - ) - - def test_finds_message(self): - """ - The search functionality will search for content in messages if asked to - do so. - """ - # The other user sends some messages - self.helper.send(self.room, body="Hi!", tok=self.other_access_token) - self.helper.send(self.room, body="There!", tok=self.other_access_token) - - channel = self.make_request( - "POST", - "/search?access_token=%s" % (self.access_token,), - { - "search_categories": { - "room_events": {"keys": ["content.body"], "search_term": "Hi"} - } - }, - ) - - # Check we get the results we expect -- one search result, of the sent - # messages - self.assertEqual(channel.code, 200) - results = channel.json_body["search_categories"]["room_events"] - self.assertEqual(results["count"], 1) - self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") - - # No context was requested, so we should get none. - self.assertEqual(results["results"][0]["context"], {}) - - def test_include_context(self): - """ - When event_context includes include_profile, profile information will be - included in the search response. - """ - # The other user sends some messages - self.helper.send(self.room, body="Hi!", tok=self.other_access_token) - self.helper.send(self.room, body="There!", tok=self.other_access_token) - - channel = self.make_request( - "POST", - "/search?access_token=%s" % (self.access_token,), - { - "search_categories": { - "room_events": { - "keys": ["content.body"], - "search_term": "Hi", - "event_context": {"include_profile": True}, - } - } - }, - ) - - # Check we get the results we expect -- one search result, of the sent - # messages - self.assertEqual(channel.code, 200) - results = channel.json_body["search_categories"]["room_events"] - self.assertEqual(results["count"], 1) - self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") - - # We should get context info, like the two users, and the display names. - context = results["results"][0]["context"] - self.assertEqual(len(context["profile_info"].keys()), 2) - self.assertEqual( - context["profile_info"][self.other_user_id]["displayname"], "otheruser" - ) - - -class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - self.url = b"/_matrix/client/r0/publicRooms" - - config = self.default_config() - config["allow_public_rooms_without_auth"] = False - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def test_restricted_no_auth(self): - channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) - - def test_restricted_auth(self): - self.register_user("user", "pass") - tok = self.login("user", "pass") - - channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) - - -class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): - """Test that we correctly fallback to local filtering if a remote server - doesn't support search. - """ - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - return self.setup_test_homeserver(federation_client=Mock()) - - def prepare(self, reactor, clock, hs): - self.register_user("user", "pass") - self.token = self.login("user", "pass") - - self.federation_client = hs.get_federation_client() - - def test_simple(self): - "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = ( - lambda *a, **k: defer.succeed({}) - ) - - search_filter = {"generic_search_term": "foobar"} - - channel = self.make_request( - "POST", - b"/_matrix/client/r0/publicRooms?server=testserv", - content={"filter": search_filter}, - access_token=self.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.federation_client.get_public_rooms.assert_called_once_with( - "testserv", - limit=100, - since_token=None, - search_filter=search_filter, - include_all_networks=False, - third_party_instance_id=None, - ) - - def test_fallback(self): - "Test that searching public rooms over federation falls back if it gets a 404" - - # The `get_public_rooms` should be called again if the first call fails - # with a 404, when using search filters. - self.federation_client.get_public_rooms.side_effect = ( - HttpResponseException(404, "Not Found", b""), - defer.succeed({}), - ) - - search_filter = {"generic_search_term": "foobar"} - - channel = self.make_request( - "POST", - b"/_matrix/client/r0/publicRooms?server=testserv", - content={"filter": search_filter}, - access_token=self.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.federation_client.get_public_rooms.assert_has_calls( - [ - call( - "testserv", - limit=100, - since_token=None, - search_filter=search_filter, - include_all_networks=False, - third_party_instance_id=None, - ), - call( - "testserv", - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, - ), - ] - ) - - -class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["allow_per_room_profiles"] = False - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - - # Set a profile for the test user - self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) - channel = self.make_request( - "PUT", - "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), - request_data, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - def test_per_room_profile_forbidden(self): - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) - channel = self.make_request( - "PUT", - "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" - % (self.room_id, self.user_id), - request_data, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res_displayname = channel.json_body["content"]["displayname"] - self.assertEqual(res_displayname, self.displayname, channel.result) - - -class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): - """Tests that clients can add a "reason" field to membership events and - that they get correctly added to the generated events and propagated. - """ - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.creator = self.register_user("creator", "test") - self.creator_tok = self.login("creator", "test") - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - - self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) - - def test_join_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/join", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_leave_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/leave", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_kick_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/kick", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_ban_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/ban", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_unban_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/unban", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_invite_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/invite", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_reject_invite_reason(self): - self.helper.invite( - self.room_id, - src=self.creator, - targ=self.second_user_id, - tok=self.creator_tok, - ) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/leave", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def _check_for_reason(self, reason): - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( - self.room_id, self.second_user_id - ), - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - event_content = channel.json_body - - self.assertEqual(event_content.get("reason"), reason, channel.result) - - -class LabelsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - profile.register_servlets, - ] - - # Filter that should only catch messages with the label "#fun". - FILTER_LABELS = { - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - } - # Filter that should only catch messages without the label "#fun". - FILTER_NOT_LABELS = { - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - } - # Filter that should only catch messages with the label "#work" but without the label - # "#notfun". - FILTER_LABELS_NOT_LABELS = { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - def test_context_filter_labels(self): - """Test that we can filter by a label on a /context request.""" - event_id = self._send_labelled_messages_in_room() - - channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" - % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual( - len(events_before), 1, [event["content"] for event in events_before] - ) - self.assertEqual( - events_before[0]["content"]["body"], "with right label", events_before[0] - ) - - events_after = channel.json_body["events_before"] - - self.assertEqual( - len(events_after), 1, [event["content"] for event in events_after] - ) - self.assertEqual( - events_after[0]["content"]["body"], "with right label", events_after[0] - ) - - def test_context_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /context request.""" - event_id = self._send_labelled_messages_in_room() - - channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" - % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual( - len(events_before), 1, [event["content"] for event in events_before] - ) - self.assertEqual( - events_before[0]["content"]["body"], "without label", events_before[0] - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual( - len(events_after), 2, [event["content"] for event in events_after] - ) - self.assertEqual( - events_after[0]["content"]["body"], "with wrong label", events_after[0] - ) - self.assertEqual( - events_after[1]["content"]["body"], "with two wrong labels", events_after[1] - ) - - def test_context_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /context request. - """ - event_id = self._send_labelled_messages_in_room() - - channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" - % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual( - len(events_before), 0, [event["content"] for event in events_before] - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual( - len(events_after), 1, [event["content"] for event in events_after] - ) - self.assertEqual( - events_after[0]["content"]["body"], "with wrong label", events_after[0] - ) - - def test_messages_filter_labels(self): - """Test that we can filter by a label on a /messages request.""" - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_messages_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /messages request.""" - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 4, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "without label", events[1]) - self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) - self.assertEqual( - events[3]["content"]["body"], "with two wrong labels", events[3] - ) - - def test_messages_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /messages request. - """ - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % ( - self.room_id, - self.tok, - token, - json.dumps(self.FILTER_LABELS_NOT_LABELS), - ), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def test_search_filter_labels(self): - """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 2, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "with right label", - results[0]["result"]["content"]["body"], - ) - self.assertEqual( - results[1]["result"]["content"]["body"], - "with right label", - results[1]["result"]["content"]["body"], - ) - - def test_search_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 4, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "without label", - results[0]["result"]["content"]["body"], - ) - self.assertEqual( - results[1]["result"]["content"]["body"], - "without label", - results[1]["result"]["content"]["body"], - ) - self.assertEqual( - results[2]["result"]["content"]["body"], - "with wrong label", - results[2]["result"]["content"]["body"], - ) - self.assertEqual( - results[3]["result"]["content"]["body"], - "with two wrong labels", - results[3]["result"]["content"]["body"], - ) - - def test_search_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /search request. - """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 1, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "with wrong label", - results[0]["result"]["content"]["body"], - ) - - def _send_labelled_messages_in_room(self): - """Sends several messages to a room with different labels (or without any) to test - filtering by label. - Returns: - The ID of the event to use if we're testing filtering on /context. - """ - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=self.tok, - ) - - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=self.tok, - ) - # Return this event's ID when we test filtering in /context requests. - event_id = res["event_id"] - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=self.tok, - ) - - return event_id - - -class ContextTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - account.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("user", "password") - self.tok = self.login("user", "password") - self.room_id = self.helper.create_room_as( - self.user_id, tok=self.tok, is_public=False - ) - - self.other_user_id = self.register_user("user2", "password") - self.other_tok = self.login("user2", "password") - - self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) - self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) - - def test_erased_sender(self): - """Test that an erasure request results in the requester's events being hidden - from any new member of the room. - """ - - # Send a bunch of events in the room. - - self.helper.send(self.room_id, "message 1", tok=self.tok) - self.helper.send(self.room_id, "message 2", tok=self.tok) - event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] - self.helper.send(self.room_id, "message 4", tok=self.tok) - self.helper.send(self.room_id, "message 5", tok=self.tok) - - # Check that we can still see the messages before the erasure request. - - channel = self.make_request( - "GET", - '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' - % (self.room_id, event_id), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual(len(events_before), 2, events_before) - self.assertEqual( - events_before[0].get("content", {}).get("body"), - "message 2", - events_before[0], - ) - self.assertEqual( - events_before[1].get("content", {}).get("body"), - "message 1", - events_before[1], - ) - - self.assertEqual( - channel.json_body["event"].get("content", {}).get("body"), - "message 3", - channel.json_body["event"], - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual(len(events_after), 2, events_after) - self.assertEqual( - events_after[0].get("content", {}).get("body"), - "message 4", - events_after[0], - ) - self.assertEqual( - events_after[1].get("content", {}).get("body"), - "message 5", - events_after[1], - ) - - # Deactivate the first account and erase the user's data. - - deactivate_account_handler = self.hs.get_deactivate_account_handler() - self.get_success( - deactivate_account_handler.deactivate_account( - self.user_id, True, create_requester(self.user_id) - ) - ) - - # Invite another user in the room. This is needed because messages will be - # pruned only if the user wasn't a member of the room when the messages were - # sent. - - invited_user_id = self.register_user("user3", "password") - invited_tok = self.login("user3", "password") - - self.helper.invite( - self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok - ) - self.helper.join(self.room_id, invited_user_id, tok=invited_tok) - - # Check that a user that joined the room after the erasure request can't see - # the messages anymore. - - channel = self.make_request( - "GET", - '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' - % (self.room_id, event_id), - access_token=invited_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual(len(events_before), 2, events_before) - self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) - self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) - - self.assertDictEqual( - channel.json_body["event"].get("content"), {}, channel.json_body["event"] - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual(len(events_after), 2, events_after) - self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) - self.assertEqual(events_after[1].get("content"), {}, events_after[1]) - - -class RoomAliasListTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.room_owner = self.register_user("room_owner", "test") - self.room_owner_tok = self.login("room_owner", "test") - - self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok - ) - - def test_no_aliases(self): - res = self._get_aliases(self.room_owner_tok) - self.assertEqual(res["aliases"], []) - - def test_not_in_room(self): - self.register_user("user", "test") - user_tok = self.login("user", "test") - res = self._get_aliases(user_tok, expected_code=403) - self.assertEqual(res["errcode"], "M_FORBIDDEN") - - def test_admin_user(self): - alias1 = self._random_alias() - self._set_alias_via_directory(alias1) - - self.register_user("user", "test", admin=True) - user_tok = self.login("user", "test") - - res = self._get_aliases(user_tok) - self.assertEqual(res["aliases"], [alias1]) - - def test_with_aliases(self): - alias1 = self._random_alias() - alias2 = self._random_alias() - - self._set_alias_via_directory(alias1) - self._set_alias_via_directory(alias2) - - res = self._get_aliases(self.room_owner_tok) - self.assertEqual(set(res["aliases"]), {alias1, alias2}) - - def test_peekable_room(self): - alias1 = self._random_alias() - self._set_alias_via_directory(alias1) - - self.helper.send_state( - self.room_id, - EventTypes.RoomHistoryVisibility, - body={"history_visibility": "world_readable"}, - tok=self.room_owner_tok, - ) - - self.register_user("user", "test") - user_tok = self.login("user", "test") - - res = self._get_aliases(user_tok) - self.assertEqual(res["aliases"], [alias1]) - - def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: - """Calls the endpoint under test. returns the json response object.""" - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,), - access_token=access_token, - ) - self.assertEqual(channel.code, expected_code, channel.result) - res = channel.json_body - self.assertIsInstance(res, dict) - if expected_code == 200: - self.assertIsInstance(res["aliases"], list) - return res - - def _random_alias(self) -> str: - return RoomAlias(random_string(5), self.hs.hostname).to_string() - - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): - url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.room_owner_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - -class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.room_owner = self.register_user("room_owner", "test") - self.room_owner_tok = self.login("room_owner", "test") - - self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok - ) - - self.alias = "#alias:test" - self._set_alias_via_directory(self.alias) - - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): - url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.room_owner_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: - """Calls the endpoint under test. returns the json response object.""" - channel = self.make_request( - "GET", - "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - access_token=self.room_owner_tok, - ) - self.assertEqual(channel.code, expected_code, channel.result) - res = channel.json_body - self.assertIsInstance(res, dict) - return res - - def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: - """Calls the endpoint under test. returns the json response object.""" - channel = self.make_request( - "PUT", - "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), - access_token=self.room_owner_tok, - ) - self.assertEqual(channel.code, expected_code, channel.result) - res = channel.json_body - self.assertIsInstance(res, dict) - return res - - def test_canonical_alias(self): - """Test a basic alias message.""" - # There is no canonical alias to start with. - self._get_canonical_alias(expected_code=404) - - # Create an alias. - self._set_canonical_alias({"alias": self.alias}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias}) - - # Now remove the alias. - self._set_canonical_alias({}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {}) - - def test_alt_aliases(self): - """Test a canonical alias message with alt_aliases.""" - # Create an alias. - self._set_canonical_alias({"alt_aliases": [self.alias]}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alt_aliases": [self.alias]}) - - # Now remove the alt_aliases. - self._set_canonical_alias({}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {}) - - def test_alias_alt_aliases(self): - """Test a canonical alias message with an alias and alt_aliases.""" - # Create an alias. - self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) - - # Now remove the alias and alt_aliases. - self._set_canonical_alias({}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {}) - - def test_partial_modify(self): - """Test removing only the alt_aliases.""" - # Create an alias. - self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) - - # Now remove the alt_aliases. - self._set_canonical_alias({"alias": self.alias}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias}) - - def test_add_alias(self): - """Test removing only the alt_aliases.""" - # Create an additional alias. - second_alias = "#second:test" - self._set_alias_via_directory(second_alias) - - # Add the canonical alias. - self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) - - # Then add the second alias. - self._set_canonical_alias( - {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} - ) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual( - res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} - ) - - def test_bad_data(self): - """Invalid data for alt_aliases should cause errors.""" - self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) - self._set_canonical_alias({"alt_aliases": None}, expected_code=400) - self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) - self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) - self._set_canonical_alias({"alt_aliases": False}, expected_code=400) - self._set_canonical_alias({"alt_aliases": True}, expected_code=400) - self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) - - def test_bad_alias(self): - """An alias which does not point to the room raises a SynapseError.""" - self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) - self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py deleted file mode 100644 index b54b004733..0000000000 --- a/tests/rest/client/v1/test_typing.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /rooms paths.""" - -from unittest.mock import Mock - -from synapse.rest.client import room -from synapse.types import UserID - -from tests import unittest - -PATH_PREFIX = "/_matrix/client/api/v1" - - -class RoomTypingTestCase(unittest.HomeserverTestCase): - """Tests /rooms/$room_id/typing/$user_id REST API.""" - - user_id = "@sid:red" - - user = UserID.from_string(user_id) - servlets = [room.register_servlets] - - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - - self.event_source = hs.get_event_sources().sources["typing"] - - hs.get_federation_handler = Mock() - - async def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - async def _insert_client_ip(*args, **kwargs): - return None - - hs.get_datastore().insert_client_ip = _insert_client_ip - - return hs - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - # Need another user to make notifications actually work - self.helper.join(self.room_id, user="@jim:red") - - def test_set_typing(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 1) - events = self.get_success( - self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) - ) - self.assertEquals( - events[0], - [ - { - "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.user_id]}, - } - ], - ) - - def test_set_not_typing(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": false}', - ) - self.assertEquals(200, channel.code) - - def test_typing_timeout(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 1) - - self.reactor.advance(36) - - self.assertEquals(self.event_source.get_current_key(), 2) - - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py deleted file mode 100644 index 954ad1a1fd..0000000000 --- a/tests/rest/client/v1/utils.py +++ /dev/null @@ -1,654 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import re -import time -import urllib.parse -from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union -from unittest.mock import patch - -import attr - -from twisted.web.resource import Resource -from twisted.web.server import Site - -from synapse.api.constants import Membership -from synapse.types import JsonDict - -from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse -from tests.test_utils.html_parsers import TestHtmlParser - - -@attr.s -class RestHelper: - """Contains extra helper functions to quickly and clearly perform a given - REST action, which isn't the focus of the test. - """ - - hs = attr.ib() - site = attr.ib(type=Site) - auth_user_id = attr.ib() - - def create_room_as( - self, - room_creator: Optional[str] = None, - is_public: bool = True, - room_version: Optional[str] = None, - tok: Optional[str] = None, - expect_code: int = 200, - extra_content: Optional[Dict] = None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ) -> str: - """ - Create a room. - - Args: - room_creator: The user ID to create the room with. - is_public: If True, the `visibility` parameter will be set to the - default (public). Otherwise, the `visibility` parameter will be set - to "private". - room_version: The room version to create the room as. Defaults to Synapse's - default room version. - tok: The access token to use in the request. - expect_code: The expected HTTP response code. - - Returns: - The ID of the newly created room. - """ - temp_id = self.auth_user_id - self.auth_user_id = room_creator - path = "/_matrix/client/r0/createRoom" - content = extra_content or {} - if not is_public: - content["visibility"] = "private" - if room_version: - content["room_version"] = room_version - if tok: - path = path + "?access_token=%s" % tok - - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - path, - json.dumps(content).encode("utf8"), - custom_headers=custom_headers, - ) - - assert channel.result["code"] == b"%d" % expect_code, channel.result - self.auth_user_id = temp_id - - if expect_code == 200: - return channel.json_body["room_id"] - - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): - self.change_membership( - room=room, - src=src, - targ=targ, - tok=tok, - membership=Membership.INVITE, - expect_code=expect_code, - ) - - def join(self, room=None, user=None, expect_code=200, tok=None): - self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.JOIN, - expect_code=expect_code, - ) - - def leave(self, room=None, user=None, expect_code=200, tok=None): - self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.LEAVE, - expect_code=expect_code, - ) - - def change_membership( - self, - room: str, - src: str, - targ: str, - membership: str, - extra_data: Optional[dict] = None, - tok: Optional[str] = None, - expect_code: int = 200, - ) -> None: - """ - Send a membership state event into a room. - - Args: - room: The ID of the room to send to - src: The mxid of the event sender - targ: The mxid of the event's target. The state key - membership: The type of membership event - extra_data: Extra information to include in the content of the event - tok: The user access token to use - expect_code: The expected HTTP response code - """ - temp_id = self.auth_user_id - self.auth_user_id = src - - path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) - if tok: - path = path + "?access_token=%s" % tok - - data = {"membership": membership} - data.update(extra_data or {}) - - channel = make_request( - self.hs.get_reactor(), - self.site, - "PUT", - path, - json.dumps(data).encode("utf8"), - ) - - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - self.auth_user_id = temp_id - - def send( - self, - room_id, - body=None, - txn_id=None, - tok=None, - expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ): - if body is None: - body = "body_text_here" - - content = {"msgtype": "m.text", "body": body} - - return self.send_event( - room_id, - "m.room.message", - content, - txn_id, - tok, - expect_code, - custom_headers=custom_headers, - ) - - def send_event( - self, - room_id, - type, - content: Optional[dict] = None, - txn_id=None, - tok=None, - expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - - path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) - if tok: - path = path + "?access_token=%s" % tok - - channel = make_request( - self.hs.get_reactor(), - self.site, - "PUT", - path, - json.dumps(content or {}).encode("utf8"), - custom_headers=custom_headers, - ) - - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - return channel.json_body - - def _read_write_state( - self, - room_id: str, - event_type: str, - body: Optional[Dict[str, Any]], - tok: str, - expect_code: int = 200, - state_key: str = "", - method: str = "GET", - ) -> Dict: - """Read or write some state from a given room - - Args: - room_id: - event_type: The type of state event - body: Body that is sent when making the request. The content of the state event. - If None, the request to the server will have an empty body - tok: The access token to use - expect_code: The HTTP code to expect in the response - state_key: - method: "GET" or "PUT" for reading or writing state, respectively - - Returns: - The response body from the server - - Raises: - AssertionError: if expect_code doesn't match the HTTP code we received - """ - path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( - room_id, - event_type, - state_key, - ) - if tok: - path = path + "?access_token=%s" % tok - - # Set request body if provided - content = b"" - if body is not None: - content = json.dumps(body).encode("utf8") - - channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - return channel.json_body - - def get_state( - self, - room_id: str, - event_type: str, - tok: str, - expect_code: int = 200, - state_key: str = "", - ): - """Gets some state from a room - - Args: - room_id: - event_type: The type of state event - tok: The access token to use - expect_code: The HTTP code to expect in the response - state_key: - - Returns: - The response body from the server - - Raises: - AssertionError: if expect_code doesn't match the HTTP code we received - """ - return self._read_write_state( - room_id, event_type, None, tok, expect_code, state_key, method="GET" - ) - - def send_state( - self, - room_id: str, - event_type: str, - body: Dict[str, Any], - tok: str, - expect_code: int = 200, - state_key: str = "", - ): - """Set some state in a room - - Args: - room_id: - event_type: The type of state event - body: Body that is sent when making the request. The content of the state event. - tok: The access token to use - expect_code: The HTTP code to expect in the response - state_key: - - Returns: - The response body from the server - - Raises: - AssertionError: if expect_code doesn't match the HTTP code we received - """ - return self._read_write_state( - room_id, event_type, body, tok, expect_code, state_key, method="PUT" - ) - - def upload_media( - self, - resource: Resource, - image_data: bytes, - tok: str, - filename: str = "test.png", - expect_code: int = 200, - ) -> dict: - """Upload a piece of test media to the media repo - Args: - resource: The resource that will handle the upload request - image_data: The image data to upload - tok: The user token to use during the upload - filename: The filename of the media to be uploaded - expect_code: The return code to expect from attempting to upload the media - """ - image_length = len(image_data) - path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - channel = make_request( - self.hs.get_reactor(), - FakeSite(resource), - "POST", - path, - content=image_data, - access_token=tok, - custom_headers=[(b"Content-Length", str(image_length))], - ) - - assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - return channel.json_body - - def login_via_oidc(self, remote_user_id: str) -> JsonDict: - """Log in (as a new user) via OIDC - - Returns the result of the final token login. - - Requires that "oidc_config" in the homeserver config be set appropriately - (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a - "public_base_url". - - Also requires the login servlet and the OIDC callback resource to be mounted at - the normal places. - """ - client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) - - # expect a confirmation page - assert channel.code == 200, channel.result - - # fish the matrix login token out of the body of the confirmation page - m = re.search( - 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), - channel.text_body, - ) - assert m, channel.text_body - login_token = m.group(1) - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - assert channel.code == 200 - return channel.json_body - - def auth_via_oidc( - self, - user_info_dict: JsonDict, - client_redirect_url: Optional[str] = None, - ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: - """Perform an OIDC authentication flow via a mock OIDC provider. - - This can be used for either login or user-interactive auth. - - Starts by making a request to the relevant synapse redirect endpoint, which is - expected to serve a 302 to the OIDC provider. We then make a request to the - OIDC callback endpoint, intercepting the HTTP requests that will get sent back - to the OIDC provider. - - Requires that "oidc_config" in the homeserver config be set appropriately - (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a - "public_base_url". - - Also requires the login servlet and the OIDC callback resource to be mounted at - the normal places. - - Args: - user_info_dict: the remote userinfo that the OIDC provider should present. - Typically this should be '{"sub": ""}'. - client_redirect_url: for a login flow, the client redirect URL to pass to - the login redirect endpoint - ui_auth_session_id: if set, we will perform a UI Auth flow. The session id - of the UI auth. - - Returns: - A FakeChannel containing the result of calling the OIDC callback endpoint. - Note that the response code may be a 200, 302 or 400 depending on how things - went. - """ - - cookies = {} - - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) - - # we now have a URI for the OIDC IdP, but we skip that and go straight - # back to synapse's OIDC callback resource. However, we do need the "state" - # param that synapse passes to the IdP via query params, as well as the cookie - # that synapse passes to the client. - - oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( - "unexpected SSO URI " + oauth_uri_path - ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) - - def complete_oidc_auth( - self, - oauth_uri: str, - cookies: Mapping[str, str], - user_info_dict: JsonDict, - ) -> FakeChannel: - """Mock out an OIDC authentication flow - - Assumes that an OIDC auth has been initiated by one of initiate_sso_login or - initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to - Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get - sent back to the OIDC provider. - - Requires the OIDC callback resource to be mounted at the normal place. - - Args: - oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, - from initiate_sso_login or initiate_sso_ui_auth). - cookies: the cookies set by synapse's redirect endpoint, which will be - sent back to the callback endpoint. - user_info_dict: the remote userinfo that the OIDC provider should present. - Typically this should be '{"sub": ""}'. - - Returns: - A FakeChannel containing the result of calling the OIDC callback endpoint. - """ - _, oauth_uri_qs = oauth_uri.split("?", 1) - params = urllib.parse.parse_qs(oauth_uri_qs) - callback_uri = "%s?%s" % ( - urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), - ) - - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req(method: str, uri: str, data=None, headers=None): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=200, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): - # now hit the callback URI with the right params and a made-up code - channel = make_request( - self.hs.get_reactor(), - self.site, - "GET", - callback_uri, - custom_headers=[ - ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() - ], - ) - return channel - - def initiate_sso_login( - self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] - ) -> str: - """Make a request to the login-via-sso redirect endpoint, and return the target - - Assumes that exactly one SSO provider has been configured. Requires the login - servlet to be mounted. - - Args: - client_redirect_url: the client redirect URL to pass to the login redirect - endpoint - cookies: any cookies returned will be added to this dict - - Returns: - the URI that the client gets redirected to (ie, the SSO server) - """ - params = {} - if client_redirect_url: - params["redirectUrl"] = client_redirect_url - - # hit the redirect url (which should redirect back to the redirect url. This - # is the easiest way of figuring out what the Host header ought to be set to - # to keep Synapse happy. - channel = make_request( - self.hs.get_reactor(), - self.site, - "GET", - "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), - ) - assert channel.code == 302 - - # hit the redirect url again with the right Host header, which should now issue - # a cookie and redirect to the SSO provider. - location = channel.headers.getRawHeaders("Location")[0] - parts = urllib.parse.urlsplit(location) - channel = make_request( - self.hs.get_reactor(), - self.site, - "GET", - urllib.parse.urlunsplit(("", "") + parts[2:]), - custom_headers=[ - ("Host", parts[1]), - ], - ) - - assert channel.code == 302 - channel.extract_cookies(cookies) - return channel.headers.getRawHeaders("Location")[0] - - def initiate_sso_ui_auth( - self, ui_auth_session_id: str, cookies: MutableMapping[str, str] - ) -> str: - """Make a request to the ui-auth-via-sso endpoint, and return the target - - Assumes that exactly one SSO provider has been configured. Requires the - AuthRestServlet to be mounted. - - Args: - ui_auth_session_id: the session id of the UI auth - cookies: any cookies returned will be added to this dict - - Returns: - the URI that the client gets linked to (ie, the SSO server) - """ - sso_redirect_endpoint = ( - "/_matrix/client/r0/auth/m.login.sso/fallback/web?" - + urllib.parse.urlencode({"session": ui_auth_session_id}) - ) - # hit the redirect url (which will issue a cookie and state) - channel = make_request( - self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint - ) - # that should serve a confirmation page - assert channel.code == 200, channel.text_body - channel.extract_cookies(cookies) - - # parse the confirmation page to fish out the link. - p = TestHtmlParser() - p.feed(channel.text_body) - p.close() - assert len(p.links) == 1, "not exactly one link in confirmation page" - oauth_uri = p.links[0] - return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc. -TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py deleted file mode 100644 index b946fca8b3..0000000000 --- a/tests/rest/client/v2_alpha/test_account.py +++ /dev/null @@ -1,1000 +0,0 @@ -# Copyright 2015-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os -import re -from email.parser import Parser -from typing import Optional - -import pkg_resources - -import synapse.rest.admin -from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes, HttpResponseException -from synapse.appservice import ApplicationService -from synapse.rest.client import account, login, register, room -from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource - -from tests import unittest -from tests.server import FakeSite, make_request -from tests.unittest import override_config - - -class PasswordResetTestCase(unittest.HomeserverTestCase): - - servlets = [ - account.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Email config. - config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - hs = self.setup_test_homeserver(config=config) - - async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] - hs.get_send_email_handler()._sendmail = sendmail - - return hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.submit_token_resource = PasswordResetSubmitTokenResource(hs) - - def test_basic_password_reset(self): - """Test basic password reset flow""" - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = self._request_token(email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - self._reset_password(new_password, session_id, client_secret) - - # Assert we can log in with the new password - self.login("kermit", new_password) - - # Assert we can't log in with the old password - self.attempt_wrong_password_login("kermit", old_password) - - @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_email(self): - """Test that we ratelimit /requestToken for the same email.""" - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test1@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - def reset(ip): - client_secret = "foobar" - session_id = self._request_token(email, client_secret, ip) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - self._reset_password(new_password, session_id, client_secret) - - self.email_attempts.clear() - - # We expect to be able to make three requests before getting rate - # limited. - # - # We change IPs to ensure that we're not being ratelimited due to the - # same IP - reset("127.0.0.1") - reset("127.0.0.2") - reset("127.0.0.3") - - with self.assertRaises(HttpResponseException) as cm: - reset("127.0.0.4") - - self.assertEqual(cm.exception.code, 429) - - def test_basic_password_reset_canonicalise_email(self): - """Test basic password reset flow - Request password reset with different spelling - """ - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email_profile = "test@example.com" - email_passwort_reset = "TEST@EXAMPLE.COM" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email_profile, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = self._request_token(email_passwort_reset, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - self._reset_password(new_password, session_id, client_secret) - - # Assert we can log in with the new password - self.login("kermit", new_password) - - # Assert we can't log in with the old password - self.attempt_wrong_password_login("kermit", old_password) - - def test_cant_reset_password_without_clicking_link(self): - """Test that we do actually need to click the link in the email""" - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = self._request_token(email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to reset password without clicking the link - self._reset_password(new_password, session_id, client_secret, expected_code=401) - - # Assert we can log in with the old password - self.login("kermit", old_password) - - # Assert we can't log in with the new password - self.attempt_wrong_password_login("kermit", new_password) - - def test_no_valid_token(self): - """Test that we do actually need to request a token and can't just - make a session up. - """ - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = "weasle" - - # Attempt to reset password without even requesting an email - self._reset_password(new_password, session_id, client_secret, expected_code=401) - - # Assert we can log in with the old password - self.login("kermit", old_password) - - # Assert we can't log in with the new password - self.attempt_wrong_password_login("kermit", new_password) - - @unittest.override_config({"request_token_inhibit_3pid_errors": True}) - def test_password_reset_bad_email_inhibit_error(self): - """Test that triggering a password reset with an email address that isn't bound - to an account doesn't leak the lack of binding for that address if configured - that way. - """ - self.register_user("kermit", "monkey") - self.login("kermit", "monkey") - - email = "test@example.com" - - client_secret = "foobar" - session_id = self._request_token(email, client_secret) - - self.assertIsNotNone(session_id) - - def _request_token(self, email, client_secret, ip="127.0.0.1"): - channel = self.make_request( - "POST", - b"account/password/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - client_ip=ip, - ) - - if channel.code != 200: - raise HttpResponseException( - channel.code, - channel.result["reason"], - channel.result["body"], - ) - - return channel.json_body["sid"] - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - # Load the password reset confirmation page - channel = make_request( - self.reactor, - FakeSite(self.submit_token_resource), - "GET", - path, - shorthand=False, - ) - - self.assertEquals(200, channel.code, channel.result) - - # Now POST to the same endpoint, mimicking the same behaviour as clicking the - # password reset confirm button - - # Confirm the password reset - channel = make_request( - self.reactor, - FakeSite(self.submit_token_resource), - "POST", - path, - content=b"", - shorthand=False, - content_is_form=True, - ) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) - - def _reset_password( - self, new_password, session_id, client_secret, expected_code=200 - ): - channel = self.make_request( - "POST", - b"account/password", - { - "new_password": new_password, - "auth": { - "type": LoginType.EMAIL_IDENTITY, - "threepid_creds": { - "client_secret": client_secret, - "sid": session_id, - }, - }, - }, - ) - self.assertEquals(expected_code, channel.code, channel.result) - - -class DeactivateTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - account.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - return self.hs - - def test_deactivate_account(self): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - - self.deactivate(user_id, tok) - - store = self.hs.get_datastore() - - # Check that the user has been marked as deactivated. - self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) - - # Check that this access token has been invalidated. - channel = self.make_request("GET", "account/whoami", access_token=tok) - self.assertEqual(channel.code, 401) - - def test_pending_invites(self): - """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() - - inviter_id = self.register_user("inviter", "test") - inviter_tok = self.login("inviter", "test") - - invitee_id = self.register_user("invitee", "test") - invitee_tok = self.login("invitee", "test") - - # Make @inviter:test invite @invitee:test in a new room. - room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) - self.helper.invite( - room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok - ) - - # Make sure the invite is here. - pending_invites = self.get_success( - store.get_invited_rooms_for_local_user(invitee_id) - ) - self.assertEqual(len(pending_invites), 1, pending_invites) - self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) - - # Deactivate @invitee:test. - self.deactivate(invitee_id, invitee_tok) - - # Check that the invite isn't there anymore. - pending_invites = self.get_success( - store.get_invited_rooms_for_local_user(invitee_id) - ) - self.assertEqual(len(pending_invites), 0, pending_invites) - - # Check that the membership of @invitee:test in the room is now "leave". - memberships = self.get_success( - store.get_rooms_for_local_user_where_membership_is( - invitee_id, [Membership.LEAVE] - ) - ) - self.assertEqual(len(memberships), 1, memberships) - self.assertEqual(memberships[0].room_id, room_id, memberships) - - def deactivate(self, user_id, tok): - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - } - ) - channel = self.make_request( - "POST", "account/deactivate", request_data, access_token=tok - ) - self.assertEqual(channel.code, 200) - - -class WhoamiTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - account.register_servlets, - register.register_servlets, - ] - - def test_GET_whoami(self): - device_id = "wouldgohere" - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test", device_id=device_id) - - whoami = self.whoami(tok) - self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) - - def test_GET_whoami_appservices(self): - user_id = "@as:test" - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": user_id, "exclusive": True}]}, - sender=user_id, - ) - self.hs.get_datastore().services_cache.append(appservice) - - whoami = self.whoami(as_token) - self.assertEqual(whoami, {"user_id": user_id}) - self.assertFalse(hasattr(whoami, "device_id")) - - def whoami(self, tok): - channel = self.make_request("GET", "account/whoami", {}, access_token=tok) - self.assertEqual(channel.code, 200) - return channel.json_body - - -class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - - servlets = [ - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Email config. - config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - self.hs = self.setup_test_homeserver(config=config) - - async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] - self.hs.get_send_email_handler()._sendmail = sendmail - - return self.hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.user_id = self.register_user("kermit", "test") - self.user_id_tok = self.login("kermit", "test") - self.email = "test@example.com" - self.url_3pid = b"account/3pid" - - def test_add_valid_email(self): - self.get_success(self._add_email(self.email, self.email)) - - def test_add_valid_email_second_time(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - self.email, - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) - ) - - def test_add_valid_email_second_time_canonicalise(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - "TEST@EXAMPLE.COM", - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) - ) - - def test_add_email_no_at(self): - self.get_success( - self._request_token_invalid_email( - "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_two_at(self): - self.get_success( - self._request_token_invalid_email( - "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_bad_format(self): - self.get_success( - self._request_token_invalid_email( - "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_domain_to_lower(self): - self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) - - def test_add_email_domain_with_umlaut(self): - self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) - - def test_add_email_address_casefold(self): - self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) - - def test_address_trim(self): - self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) - - @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_ip(self): - """Tests that adding emails is ratelimited by IP""" - - # We expect to be able to set three emails before getting ratelimited. - self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) - self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) - self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) - - with self.assertRaises(HttpResponseException) as cm: - self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) - - self.assertEqual(cm.exception.code, 429) - - def test_add_email_if_disabled(self): - """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False - - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email(self): - """Test deleting an email from profile""" - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email_if_disabled(self): - """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email""" - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to add email without clicking the link - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_no_valid_token(self): - """Test that we do actually need to request a token and can't just - make a session up. - """ - client_secret = "foobar" - session_id = "weasle" - - # Attempt to add email without even requesting an email - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - @override_config({"next_link_domain_whitelist": None}) - def test_next_link(self): - """Tests a valid next_link parameter value with no whitelist (good case)""" - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/a/good/site", - expect_code=200, - ) - - @override_config({"next_link_domain_whitelist": None}) - def test_next_link_exotic_protocol(self): - """Tests using a esoteric protocol as a next_link parameter value. - Someone may be hosting a client on IPFS etc. - """ - self._request_token( - "something@example.com", - "some_secret", - next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, - ) - - @override_config({"next_link_domain_whitelist": None}) - def test_next_link_file_uri(self): - """Tests next_link parameters cannot be file URI""" - # Attempt to use a next_link value that points to the local disk - self._request_token( - "something@example.com", - "some_secret", - next_link="file:///host/path", - expect_code=400, - ) - - @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) - def test_next_link_domain_whitelist(self): - """Tests next_link parameters must fit the whitelist if provided""" - - # Ensure not providing a next_link parameter still works - self._request_token( - "something@example.com", - "some_secret", - next_link=None, - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/some/good/page", - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.org/some/also/good/page", - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://bad.example.org/some/bad/page", - expect_code=400, - ) - - @override_config({"next_link_domain_whitelist": []}) - def test_empty_next_link_domain_whitelist(self): - """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially - disallowed - """ - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/a/page", - expect_code=400, - ) - - def _request_token( - self, - email: str, - client_secret: str, - next_link: Optional[str] = None, - expect_code: int = 200, - ) -> str: - """Request a validation token to add an email address to a user's account - - Args: - email: The email address to validate - client_secret: A secret string - next_link: A link to redirect the user to after validation - expect_code: Expected return code of the call - - Returns: - The ID of the new threepid validation session - """ - body = {"client_secret": client_secret, "email": email, "send_attempt": 1} - if next_link: - body["next_link"] = next_link - - channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - body, - ) - - if channel.code != expect_code: - raise HttpResponseException( - channel.code, - channel.result["reason"], - channel.result["body"], - ) - - return channel.json_body.get("sid") - - def _request_token_invalid_email( - self, - email, - expected_errcode, - expected_error, - client_secret="foobar", - ): - channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(expected_errcode, channel.json_body["errcode"]) - self.assertEqual(expected_error, channel.json_body["error"]) - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - channel = self.make_request("GET", path, shorthand=False) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) - - def _add_email(self, request_email, expected_email): - """Test adding an email to profile""" - previous_email_attempts = len(self.email_attempts) - - client_secret = "foobar" - session_id = self._request_token(request_email, client_secret) - - self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) - link = self._get_link_from_email() - - self._validate_token(link) - - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - - threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} - self.assertIn(expected_email, threepids) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py deleted file mode 100644 index cf5cfb910c..0000000000 --- a/tests/rest/client/v2_alpha/test_auth.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright 2018 New Vector -# Copyright 2020-2021 The Matrix.org Foundation C.I.C -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional, Union - -from twisted.internet.defer import succeed - -import synapse.rest.admin -from synapse.api.constants import LoginType -from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client import account, auth, devices, login, register -from synapse.rest.synapse.client import build_synapse_client_resource_tree -from synapse.types import JsonDict, UserID - -from tests import unittest -from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.v1.utils import TEST_OIDC_CONFIG -from tests.server import FakeChannel -from tests.unittest import override_config, skip_unless - - -class DummyRecaptchaChecker(UserInteractiveAuthChecker): - def __init__(self, hs): - super().__init__(hs) - self.recaptcha_attempts = [] - - def check_auth(self, authdict, clientip): - self.recaptcha_attempts.append((authdict, clientip)) - return succeed(True) - - -class FallbackAuthTests(unittest.HomeserverTestCase): - - servlets = [ - auth.register_servlets, - register.register_servlets, - ] - hijack_auth = False - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - - config["enable_registration_captcha"] = True - config["recaptcha_public_key"] = "brokencake" - config["registrations_require_3pid"] = [] - - hs = self.setup_test_homeserver(config=config) - return hs - - def prepare(self, reactor, clock, hs): - self.recaptcha_checker = DummyRecaptchaChecker(hs) - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - - def register(self, expected_response: int, body: JsonDict) -> FakeChannel: - """Make a register request.""" - channel = self.make_request("POST", "register", body) - - self.assertEqual(channel.code, expected_response) - return channel - - def recaptcha( - self, - session: str, - expected_post_response: int, - post_session: Optional[str] = None, - ) -> None: - """Get and respond to a fallback recaptcha. Returns the second request.""" - if post_session is None: - post_session = session - - channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) - self.assertEqual(channel.code, 200) - - channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + post_session - + "&g-recaptcha-response=a", - ) - self.assertEqual(channel.code, expected_post_response) - - # The recaptcha handler is called with the response given - attempts = self.recaptcha_checker.recaptcha_attempts - self.assertEqual(len(attempts), 1) - self.assertEqual(attempts[0][0]["response"], "a") - - def test_fallback_captcha(self): - """Ensure that fallback auth via a captcha works.""" - # Returns a 401 as per the spec - channel = self.register( - 401, - {"username": "user", "type": "m.login.password", "password": "bar"}, - ) - - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) - - # Complete the recaptcha step. - self.recaptcha(session, 200) - - # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) - - # Now we should have fulfilled a complete auth flow, including - # the recaptcha fallback step, we can then send a - # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) - - # We're given a registered user. - self.assertEqual(channel.json_body["user_id"], "@user:test") - - def test_complete_operation_unknown_session(self): - """ - Attempting to mark an invalid session as complete should error. - """ - # Make the initial request to register. (Later on a different password - # will be used.) - # Returns a 401 as per the spec - channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} - ) - - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) - - # Attempt to complete the recaptcha step with an unknown session. - # This results in an error. - self.recaptcha(session, 400, session + "unknown") - - -class UIAuthTests(unittest.HomeserverTestCase): - servlets = [ - auth.register_servlets, - devices.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - ] - - def default_config(self): - config = super().default_config() - - # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns - # False, so synapse will see the requested uri as http://..., so using http in - # the public_baseurl stops Synapse trying to redirect to https. - config["public_baseurl"] = "http://synapse.test" - - if HAS_OIDC: - # we enable OIDC as a way of testing SSO flows - oidc_config = {} - oidc_config.update(TEST_OIDC_CONFIG) - oidc_config["allow_existing_users"] = True - config["oidc_config"] = oidc_config - - return config - - def create_resource_dict(self): - resource_dict = super().create_resource_dict() - resource_dict.update(build_synapse_client_resource_tree(self.hs)) - return resource_dict - - def prepare(self, reactor, clock, hs): - self.user_pass = "pass" - self.user = self.register_user("test", self.user_pass) - self.device_id = "dev1" - self.user_tok = self.login("test", self.user_pass, self.device_id) - - def delete_device( - self, - access_token: str, - device: str, - expected_response: int, - body: Union[bytes, JsonDict] = b"", - ) -> FakeChannel: - """Delete an individual device.""" - channel = self.make_request( - "DELETE", - "devices/" + device, - body, - access_token=access_token, - ) - - # Ensure the response is sane. - self.assertEqual(channel.code, expected_response) - - return channel - - def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: - """Delete 1 or more devices.""" - # Note that this uses the delete_devices endpoint so that we can modify - # the payload half-way through some tests. - channel = self.make_request( - "POST", - "delete_devices", - body, - access_token=self.user_tok, - ) - - # Ensure the response is sane. - self.assertEqual(channel.code, expected_response) - - return channel - - def test_ui_auth(self): - """ - Test user interactive authentication outside of registration. - """ - # Attempt to delete this device. - # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow. - self.delete_device( - self.user_tok, - self.device_id, - 200, - { - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - def test_grandfathered_identifier(self): - """Check behaviour without "identifier" dict - - Synapse used to require clients to submit a "user" field for m.login.password - UIA - check that still works. - """ - - channel = self.delete_device(self.user_tok, self.device_id, 401) - session = channel.json_body["session"] - - # Make another request providing the UI auth flow. - self.delete_device( - self.user_tok, - self.device_id, - 200, - { - "auth": { - "type": "m.login.password", - "user": self.user, - "password": self.user_pass, - "session": session, - }, - }, - ) - - def test_can_change_body(self): - """ - The client dict can be modified during the user interactive authentication session. - - Note that it is not spec compliant to modify the client dict during a - user interactive authentication session, but many clients currently do. - - When Synapse is updated to be spec compliant, the call to re-use the - session ID should be rejected. - """ - # Create a second login. - self.login("test", self.user_pass, "dev2") - - # Attempt to delete the first device. - # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow, but try to delete the - # second device. - self.delete_devices( - 200, - { - "devices": ["dev2"], - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - def test_cannot_change_uri(self): - """ - The initial requested URI cannot be modified during the user interactive authentication session. - """ - # Create a second login. - self.login("test", self.user_pass, "dev2") - - # Attempt to delete the first device. - # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow, but try to delete the - # second device. This results in an error. - # - # This makes use of the fact that the device ID is embedded into the URL. - self.delete_device( - self.user_tok, - "dev2", - 403, - { - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) - def test_can_reuse_session(self): - """ - The session can be reused if configured. - - Compare to test_cannot_change_uri. - """ - # Create a second and third login. - self.login("test", self.user_pass, "dev2") - self.login("test", self.user_pass, "dev3") - - # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) - - # Move the clock forward past the validation timeout. - self.reactor.advance(6) - - # Deleting another devices throws the user into UI auth. - channel = self.delete_device(self.user_tok, "dev3", 401) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow. - self.delete_device( - self.user_tok, - "dev3", - 200, - { - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - # Make another request, but try to delete the first device. This works - # due to re-using the previous session. - # - # Note that *no auth* information is provided, not even a session iD! - self.delete_device(self.user_tok, self.device_id, 200) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_via_sso(self): - """Test a successful UI Auth flow via SSO - - This includes: - * hitting the UIA SSO redirect endpoint - * checking it serves a confirmation page which links to the OIDC provider - * calling back to the synapse oidc callback - * checking that the original operation succeeds - """ - - # log the user in - remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) - self.assertEqual(login_resp["user_id"], self.user) - - # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) - - # check that SSO is offered - flows = channel.json_body["flows"] - self.assertIn({"stages": ["m.login.sso"]}, flows) - - # run the UIA-via-SSO flow - session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id - ) - - # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) - - # and now the delete request should succeed. - self.delete_device( - self.user_tok, - self.device_id, - 200, - body={"auth": {"session": session_id}}, - ) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_does_not_offer_password_for_sso_user(self): - login_resp = self.helper.login_via_oidc("username") - user_tok = login_resp["access_token"] - device_id = login_resp["device_id"] - - # now call the device deletion API: we should get the option to auth with SSO - # and not password. - channel = self.delete_device(user_tok, device_id, 401) - - flows = channel.json_body["flows"] - self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) - - def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - self.assertEqual(flows, [{"stages": ["m.login.password"]}]) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_offers_both_flows_for_upgraded_user(self): - """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) - self.assertEqual(login_resp["user_id"], self.user) - - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - # we have no particular expectations of ordering here - self.assertIn({"stages": ["m.login.password"]}, flows) - self.assertIn({"stages": ["m.login.sso"]}, flows) - self.assertEqual(len(flows), 2) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_fails_for_incorrect_sso_user(self): - """If the user tries to authenticate with the wrong SSO user, they get an error""" - # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) - self.assertEqual(login_resp["user_id"], self.user) - - # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - self.assertIn({"stages": ["m.login.sso"]}, flows) - session_id = channel.json_body["session"] - - # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id - ) - - # that should return a failure message - self.assertSubstring("We were unable to validate", channel.text_body) - - # ... and the delete op should now fail with a 403 - self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} - ) - - -class RefreshAuthTests(unittest.HomeserverTestCase): - servlets = [ - auth.register_servlets, - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - ] - hijack_auth = False - - def prepare(self, reactor, clock, hs): - self.user_pass = "pass" - self.user = self.register_user("test", self.user_pass) - - def test_login_issue_refresh_token(self): - """ - A login response should include a refresh_token only if asked. - """ - # Test login - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - - login_without_refresh = self.make_request( - "POST", "/_matrix/client/r0/login", body - ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) - self.assertNotIn("refresh_token", login_without_refresh.json_body) - - login_with_refresh = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) - self.assertIn("refresh_token", login_with_refresh.json_body) - self.assertIn("expires_in_ms", login_with_refresh.json_body) - - def test_register_issue_refresh_token(self): - """ - A register response should include a refresh_token only if asked. - """ - register_without_refresh = self.make_request( - "POST", - "/_matrix/client/r0/register", - { - "username": "test2", - "password": self.user_pass, - "auth": {"type": LoginType.DUMMY}, - }, - ) - self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result - ) - self.assertNotIn("refresh_token", register_without_refresh.json_body) - - register_with_refresh = self.make_request( - "POST", - "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", - { - "username": "test3", - "password": self.user_pass, - "auth": {"type": LoginType.DUMMY}, - }, - ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) - self.assertIn("refresh_token", register_with_refresh.json_body) - self.assertIn("expires_in_ms", register_with_refresh.json_body) - - def test_token_refresh(self): - """ - A refresh token can be used to issue a new access token. - """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - - refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertIn("access_token", refresh_response.json_body) - self.assertIn("refresh_token", refresh_response.json_body) - self.assertIn("expires_in_ms", refresh_response.json_body) - - # The access and refresh tokens should be different from the original ones after refresh - self.assertNotEqual( - login_response.json_body["access_token"], - refresh_response.json_body["access_token"], - ) - self.assertNotEqual( - login_response.json_body["refresh_token"], - refresh_response.json_body["refresh_token"], - ) - - @override_config({"access_token_lifetime": "1m"}) - def test_refresh_token_expiration(self): - """ - The access token should have some time as specified in the config. - """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - self.assertApproximates( - login_response.json_body["expires_in_ms"], 60 * 1000, 100 - ) - - refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertApproximates( - refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 - ) - - def test_refresh_token_invalidation(self): - """Refresh tokens are invalidated after first use of the next token. - - A refresh token is considered invalid if: - - it was already used at least once - - and either - - the next access token was used - - the next refresh token was used - - The chain of tokens goes like this: - - login -|-> first_refresh -> third_refresh (fails) - |-> second_refresh -> fifth_refresh - |-> fourth_refresh (fails) - """ - - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - - # This first refresh should work properly - first_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result - ) - - # This one as well, since the token in the first one was never used - second_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result - ) - - # This one should not, since the token from the first refresh is not valid anymore - third_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": first_refresh_response.json_body["refresh_token"]}, - ) - self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result - ) - - # The associated access token should also be invalid - whoami_response = self.make_request( - "GET", - "/_matrix/client/r0/account/whoami", - access_token=first_refresh_response.json_body["access_token"], - ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) - - # But all other tokens should work (they will expire after some time) - for access_token in [ - second_refresh_response.json_body["access_token"], - login_response.json_body["access_token"], - ]: - whoami_response = self.make_request( - "GET", "/_matrix/client/r0/account/whoami", access_token=access_token - ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) - - # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail - fourth_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result - ) - - # But refreshing from the last valid refresh token still works - fifth_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": second_refresh_response.json_body["refresh_token"]}, - ) - self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result - ) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py deleted file mode 100644 index ac31e5ceaf..0000000000 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import synapse.rest.admin -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.rest.client import capabilities, login - -from tests import unittest -from tests.unittest import override_config - - -class CapabilitiesTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - capabilities.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.url = b"/_matrix/client/r0/capabilities" - hs = self.setup_test_homeserver() - self.config = hs.config - self.auth_handler = hs.get_auth_handler() - return hs - - def prepare(self, reactor, clock, hs): - self.localpart = "user" - self.password = "pass" - self.user = self.register_user(self.localpart, self.password) - - def test_check_auth_required(self): - channel = self.make_request("GET", self.url) - - self.assertEqual(channel.code, 401) - - def test_get_room_version_capabilities(self): - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - for room_version in capabilities["m.room_versions"]["available"].keys(): - self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) - - self.assertEqual( - self.config.default_room_version.identifier, - capabilities["m.room_versions"]["default"], - ) - - def test_get_change_password_capabilities_password_login(self): - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - - @override_config({"password_config": {"localdb_enabled": False}}) - def test_get_change_password_capabilities_localdb_disabled(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["m.change_password"]["enabled"]) - - @override_config({"password_config": {"enabled": False}}) - def test_get_change_password_capabilities_password_disabled(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["m.change_password"]["enabled"]) - - def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): - """Test that per default msc3283 is disabled server returns `m.change_password`.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) - self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) - self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) - - @override_config({"experimental_features": {"msc3283_enabled": True}}) - def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): - """Test if msc3283 is enabled server returns capabilities.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) - - @override_config( - { - "enable_set_displayname": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_get_set_displayname_capabilities_displayname_disabled(self): - """Test if set displayname is disabled that the server responds it.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - - @override_config( - { - "enable_set_avatar_url": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): - """Test if set avatar_url is disabled that the server responds it.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - - @override_config( - { - "enable_3pid_changes": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_change_3pid_capabilities_3pid_disabled(self): - """Test if change 3pid is disabled that the server responds it.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) - - def test_get_does_not_include_msc3244_fields_by_default(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertNotIn( - "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] - ) - - @override_config({"experimental_features": {"msc3244_enabled": True}}) - def test_get_does_include_msc3244_fields_when_enabled(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - for details in capabilities["m.room_versions"][ - "org.matrix.msc3244.room_capabilities" - ].values(): - if details["preferred"] is not None: - self.assertTrue( - details["preferred"] in KNOWN_ROOM_VERSIONS, - str(details["preferred"]), - ) - - self.assertGreater(len(details["support"]), 0) - for room_version in details["support"]: - self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py deleted file mode 100644 index 475c6bed3d..0000000000 --- a/tests/rest/client/v2_alpha/test_filter.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from synapse.api.errors import Codes -from synapse.rest.client import filter - -from tests import unittest - -PATH_PREFIX = "/_matrix/client/v2_alpha" - - -class FilterTestCase(unittest.HomeserverTestCase): - - user_id = "@apple:test" - hijack_auth = True - EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} - EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' - servlets = [filter.register_servlets] - - def prepare(self, reactor, clock, hs): - self.filtering = hs.get_filtering() - self.store = hs.get_datastore() - - def test_add_filter(self): - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % (self.user_id), - self.EXAMPLE_FILTER_JSON, - ) - - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.json_body, {"filter_id": "0"}) - filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) - self.pump() - self.assertEquals(filter.result, self.EXAMPLE_FILTER) - - def test_add_filter_for_other_user(self): - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), - self.EXAMPLE_FILTER_JSON, - ) - - self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - - def test_add_filter_non_local_user(self): - _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % (self.user_id), - self.EXAMPLE_FILTER_JSON, - ) - - self.hs.is_mine = _is_mine - self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - - def test_get_filter(self): - filter_id = defer.ensureDeferred( - self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER - ) - ) - self.reactor.advance(1) - filter_id = filter_id.result - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) - ) - - self.assertEqual(channel.result["code"], b"200") - self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) - - def test_get_filter_non_existant(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"404") - self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) - - # Currently invalid params do not have an appropriate errcode - # in errors.py - def test_get_filter_invalid_id(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"400") - - # No ID also returns an invalid_id error - def test_get_filter_no_id(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py deleted file mode 100644 index d7fa635eae..0000000000 --- a/tests/rest/client/v2_alpha/test_keys.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License - -from http import HTTPStatus - -from synapse.api.errors import Codes -from synapse.rest import admin -from synapse.rest.client import keys, login - -from tests import unittest - - -class KeyQueryTestCase(unittest.HomeserverTestCase): - servlets = [ - keys.register_servlets, - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def test_rejects_device_id_ice_key_outside_of_list(self): - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - bob = self.register_user("bob", "uncle") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - { - "device_keys": { - bob: "device_id1", - }, - }, - alice_token, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) - - def test_rejects_device_key_given_as_map_to_bool(self): - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - bob = self.register_user("bob", "uncle") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - { - "device_keys": { - bob: { - "device_id1": True, - }, - }, - }, - alice_token, - ) - - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) - - def test_requires_device_key(self): - """`device_keys` is required. We should complain if it's missing.""" - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - {}, - alice_token, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py deleted file mode 100644 index 3cf5871899..0000000000 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.api.constants import LoginType -from synapse.api.errors import Codes -from synapse.rest import admin -from synapse.rest.client import account, login, password_policy, register - -from tests import unittest - - -class PasswordPolicyTestCase(unittest.HomeserverTestCase): - """Tests the password policy feature and its compliance with MSC2000. - - When validating a password, Synapse does the necessary checks in this order: - - 1. Password is long enough - 2. Password contains digit(s) - 3. Password contains symbol(s) - 4. Password contains uppercase letter(s) - 5. Password contains lowercase letter(s) - - For each test below that checks whether a password triggers the right error code, - that test provides a password good enough to pass the previous tests, but not the - one it is currently testing (nor any test that comes afterward). - """ - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - register.register_servlets, - password_policy.register_servlets, - account.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.register_url = "/_matrix/client/r0/register" - self.policy = { - "enabled": True, - "minimum_length": 10, - "require_digit": True, - "require_symbol": True, - "require_lowercase": True, - "require_uppercase": True, - } - - config = self.default_config() - config["password_config"] = { - "policy": self.policy, - } - - hs = self.setup_test_homeserver(config=config) - return hs - - def test_get_policy(self): - """Tests if the /password_policy endpoint returns the configured policy.""" - - channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual( - channel.json_body, - { - "m.minimum_length": 10, - "m.require_digit": True, - "m.require_symbol": True, - "m.require_lowercase": True, - "m.require_uppercase": True, - }, - channel.result, - ) - - def test_password_too_short(self): - request_data = json.dumps({"username": "kermit", "password": "shorty"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_TOO_SHORT, - channel.result, - ) - - def test_password_no_digit(self): - request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_DIGIT, - channel.result, - ) - - def test_password_no_symbol(self): - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_SYMBOL, - channel.result, - ) - - def test_password_no_uppercase(self): - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_UPPERCASE, - channel.result, - ) - - def test_password_no_lowercase(self): - request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_LOWERCASE, - channel.result, - ) - - def test_password_compliant(self): - request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) - channel = self.make_request("POST", self.register_url, request_data) - - # Getting a 401 here means the password has passed validation and the server has - # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) - - def test_password_change(self): - """This doesn't test every possible use case, only that hitting /account/password - triggers the password validation code. - """ - compliant_password = "C0mpl!antpassword" - not_compliant_password = "notcompliantpassword" - - user_id = self.register_user("kermit", compliant_password) - tok = self.login("kermit", compliant_password) - - request_data = json.dumps( - { - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, - }, - } - ) - channel = self.make_request( - "POST", - "/_matrix/client/r0/account/password", - request_data, - access_token=tok, - ) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py deleted file mode 100644 index fecda037a5..0000000000 --- a/tests/rest/client/v2_alpha/test_register.py +++ /dev/null @@ -1,746 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import datetime -import json -import os - -import pkg_resources - -import synapse.rest.admin -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType -from synapse.api.errors import Codes -from synapse.appservice import ApplicationService -from synapse.rest.client import account, account_validity, login, logout, register, sync - -from tests import unittest -from tests.unittest import override_config - - -class RegisterRestServletTestCase(unittest.HomeserverTestCase): - - servlets = [ - login.register_servlets, - register.register_servlets, - synapse.rest.admin.register_servlets, - ] - url = b"/_matrix/client/r0/register" - - def default_config(self): - config = super().default_config() - config["allow_guest_access"] = True - return config - - def test_POST_appservice_registration_valid(self): - user_id = "@as_user_kermit:test" - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - sender="@as:test", - ) - - self.hs.get_datastore().services_cache.append(appservice) - request_data = json.dumps( - {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) - - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - det_data = {"user_id": user_id, "home_server": self.hs.hostname} - self.assertDictContainsSubset(det_data, channel.json_body) - - def test_POST_appservice_registration_no_type(self): - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - sender="@as:test", - ) - - self.hs.get_datastore().services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) - - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"400", channel.result) - - def test_POST_appservice_registration_invalid(self): - self.appservice = None # no application service exists - request_data = json.dumps( - {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"401", channel.result) - - def test_POST_bad_password(self): - request_data = json.dumps({"username": "kermit", "password": 666}) - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid password") - - def test_POST_bad_username(self): - request_data = json.dumps({"username": 777, "password": "monkey"}) - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") - - def test_POST_user_valid(self): - user_id = "@kermit:test" - device_id = "frogfone" - params = { - "username": "kermit", - "password": "monkey", - "device_id": device_id, - "auth": {"type": LoginType.DUMMY}, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - - det_data = { - "user_id": user_id, - "home_server": self.hs.hostname, - "device_id": device_id, - } - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) - - @override_config({"enable_registration": False}) - def test_POST_disabled_registration(self): - request_data = json.dumps({"username": "kermit", "password": "monkey"}) - self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Registration has been disabled") - self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") - - def test_POST_guest_registration(self): - self.hs.config.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) - - def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Guest access is disabled") - - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting_guest(self): - for i in range(0, 6): - url = self.url + b"?kind=guest" - channel = self.make_request(b"POST", url, b"{}") - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting(self): - for i in range(0, 6): - params = { - "username": "kermit" + str(i), - "password": "monkey", - "device_id": "frogfone", - "auth": {"type": LoginType.DUMMY}, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_advertised_flows(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - # with the stock config, we only expect the dummy flow - self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "enable_registration_captcha": True, - "user_consent": { - "version": "1", - "template_dir": "/", - "require_at_registration": True, - }, - "account_threepid_delegates": { - "email": "https://id_server", - "msisdn": "https://id_server", - }, - } - ) - def test_advertised_flows_captcha_and_terms_and_3pids(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - self.assertCountEqual( - [ - ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], - ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], - ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], - [ - "m.login.recaptcha", - "m.login.terms", - "m.login.msisdn", - "m.login.email.identity", - ], - ], - (f["stages"] for f in flows), - ) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "registrations_require_3pid": ["email"], - "disable_msisdn_registration": True, - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_advertised_flows_no_msisdn_email_required(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - # with the stock config, we expect all four combinations of 3pid - self.assertCountEqual( - [["m.login.email.identity"]], (f["stages"] for f in flows) - ) - - @unittest.override_config( - { - "request_token_inhibit_3pid_errors": True, - "public_baseurl": "https://test_server", - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_request_token_existing_email_inhibit_error(self): - """Test that requesting a token via this endpoint doesn't leak existing - associations if configured that way. - """ - user_id = self.register_user("kermit", "monkey") - self.login("kermit", "monkey") - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.hs.get_datastore().user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": email, "send_attempt": 1}, - ) - self.assertEquals(200, channel.code, channel.result) - - self.assertIsNotNone(channel.json_body.get("sid")) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_reject_invalid_email(self): - """Check that bad emails are rejected""" - - # Test for email with multiple @ - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - # Check error to ensure that we're not erroring due to a bug in the test. - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - # Test for email with no @ - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": "email", "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - # Test for super long email - email = "a@" + "a" * 1000 - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": email, "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - -class AccountValidityTestCase(unittest.HomeserverTestCase): - - servlets = [ - register.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - sync.register_servlets, - logout.register_servlets, - account_validity.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - # Test for account expiring after a week. - config["enable_registration"] = True - config["account_validity"] = { - "enabled": True, - "period": 604800000, # Time in ms for 1 week - } - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def test_validity_period(self): - self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. - channel = self.make_request(b"GET", "/sync", access_token=tok) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - - channel = self.make_request(b"GET", "/sync", access_token=tok) - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result - ) - - def test_manual_renewal(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - - # If we register the admin user at the beginning of the test, it will - # expire at the same time as the normal user and the renewal request - # will be denied. - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = {"user_id": user_id} - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_manual_expire(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = { - "user_id": user_id, - "expiration_ts": 0, - "enable_renewal_emails": False, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result - ) - - def test_logging_out_expired_user(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = { - "user_id": user_id, - "expiration_ts": 0, - "enable_renewal_emails": False, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Try to log the user out - channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Log the user in again (allowed for expired accounts) - tok = self.login("kermit", "monkey") - - # Try to log out all of the user's sessions - channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - -class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - - servlets = [ - register.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - sync.register_servlets, - account_validity.register_servlets, - account.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Test for account expiring after a week and renewal emails being sent 2 - # days before expiry. - config["enable_registration"] = True - config["account_validity"] = { - "enabled": True, - "period": 604800000, # Time in ms for 1 week - "renew_at": 172800000, # Time in ms for 2 days - "renew_by_email_enabled": True, - "renew_email_subject": "Renew your account", - "account_renewed_html_path": "account_renewed.html", - "invalid_token_html_path": "invalid_token.html", - } - - # Email config. - - config["email"] = { - "enable_notifs": True, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "expiry_template_html": "notice_expiry.html", - "expiry_template_text": "notice_expiry.txt", - "notif_template_html": "notif_mail.html", - "notif_template_text": "notif_mail.txt", - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "aaa" - - self.hs = self.setup_test_homeserver(config=config) - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) - - self.email_attempts = [] - self.hs.get_send_email_handler()._sendmail = sendmail - - self.store = self.hs.get_datastore() - - return self.hs - - def test_renewal_email(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - - # Move 5 days forward. This should trigger a renewal email to be sent. - self.reactor.advance(datetime.timedelta(days=5).total_seconds()) - self.assertEqual(len(self.email_attempts), 1) - - # Retrieving the URL from the email is too much pain for now, so we - # retrieve the token from the DB. - renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) - url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect on a successful renewal. - expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( - expiration_ts=expiration_ts - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - # Move 1 day forward. Try to renew with the same token again. - url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect when reusing a - # token. The account expiration date should not have changed. - expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( - expiration_ts=expiration_ts - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - # Move 3 days forward. If the renewal failed, every authed request with - # our access token should be denied from now, otherwise they should - # succeed. - self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_renewal_invalid_token(self): - # Hit the renewal endpoint with an invalid token and check that it behaves as - # expected, i.e. that it responds with 404 Not Found and the correct HTML. - url = "/_matrix/client/unstable/account_validity/renew?token=123" - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"404", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect when using an - # invalid/unknown token. - expected_html = ( - self.hs.config.account_validity.account_validity_invalid_token_template.render() - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - def test_manual_email_send(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - channel = self.make_request( - b"POST", - "/_matrix/client/unstable/account_validity/send_mail", - access_token=tok, - ) - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.assertEqual(len(self.email_attempts), 1) - - def test_deactivated_user(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - } - ) - channel = self.make_request( - "POST", "account/deactivate", request_data, access_token=tok - ) - self.assertEqual(channel.code, 200) - - self.reactor.advance(datetime.timedelta(days=8).total_seconds()) - - self.assertEqual(len(self.email_attempts), 0) - - def create_user(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.get_clock().time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) - return user_id, tok - - def test_manual_email_send_expired_account(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.get_clock().time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) - - # Make the account expire. - self.reactor.advance(datetime.timedelta(days=8).total_seconds()) - - # Ignore all emails sent by the automatic background task and only focus on the - # ones sent manually. - self.email_attempts = [] - - # Test that we're still able to manually trigger a mail to be sent. - channel = self.make_request( - b"POST", - "/_matrix/client/unstable/account_validity/send_mail", - access_token=tok, - ) - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.assertEqual(len(self.email_attempts), 1) - - -class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - - def make_homeserver(self, reactor, clock): - self.validity_period = 10 - self.max_delta = self.validity_period * 10.0 / 100.0 - - config = self.default_config() - - config["enable_registration"] = True - config["account_validity"] = {"enabled": False} - - self.hs = self.setup_test_homeserver(config=config) - - # We need to set these directly, instead of in the homeserver config dict above. - # This is due to account validity-related config options not being read by - # Synapse when account_validity.enabled is False. - self.hs.get_datastore()._account_validity_period = self.validity_period - self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta - - self.store = self.hs.get_datastore() - - return self.hs - - def test_background_job(self): - """ - Tests the same thing as test_background_job, except that it sets the - startup_job_max_delta parameter and checks that the expiration date is within the - allowed range. - """ - user_id = self.register_user("kermit_delta", "user") - - self.hs.config.account_validity.startup_job_max_delta = self.max_delta - - now_ms = self.hs.get_clock().time_msec() - self.get_success(self.store._set_expiration_date_when_missing()) - - res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - - self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) - self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py deleted file mode 100644 index 02b5e9a8d0..0000000000 --- a/tests/rest/client/v2_alpha/test_relations.py +++ /dev/null @@ -1,724 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import json -import urllib -from typing import Optional - -from synapse.api.constants import EventTypes, RelationTypes -from synapse.rest import admin -from synapse.rest.client import login, register, relations, room - -from tests import unittest - - -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - relations.register_servlets, - room.register_servlets, - login.register_servlets, - register.register_servlets, - admin.register_servlets_for_client_rest_resource, - ] - hijack_auth = False - - def make_homeserver(self, reactor, clock): - # We need to enable msc1849 support for aggregations - config = self.default_config() - config["experimental_msc1849_support_enabled"] = True - - # We enable frozen dicts as relations/edits change event contents, so we - # want to test that we don't modify the events in the caches. - config["use_frozen_dicts"] = True - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor, clock, hs): - self.user_id, self.user_token = self._create_user("alice") - self.user2_id, self.user2_token = self._create_user("bob") - - self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) - self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) - res = self.helper.send(self.room, body="Hi!", tok=self.user_token) - self.parent_id = res["event_id"] - - def test_send_relation(self): - """Tests that sending a relation using the new /send_relation works - creates the right shape of event. - """ - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) - - event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assert_dict( - { - "type": "m.reaction", - "sender": self.user_id, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "key": "👍", - "rel_type": RelationTypes.ANNOTATION, - } - }, - }, - channel.json_body, - ) - - def test_deny_membership(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) - self.assertEquals(400, channel.code, channel.json_body) - - def test_deny_double_react(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(400, channel.code, channel.json_body) - - def test_basic_paginate_relations(self): - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the full - # relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - - # Make sure next_batch has something in it that looks like it could be a - # valid token. - self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) - - def test_repeated_paginate_relations(self): - """Test that if we paginate using a limit and tokens then we get the - expected events. - """ - - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - prev_token = None - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) - - def test_aggregation_pagination_groups(self): - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEquals(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token = None - found_groups = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEquals(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self): - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - self.assertEquals(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) - - prev_token = None - found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s" - "/aggregations/%s/%s/m.reaction/%s?limit=1%s" - % ( - self.room, - self.parent_id, - RelationTypes.ANNOTATION, - encoded_key, - from_token, - ), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) - - def test_aggregation(self): - """Test that annotations get correctly aggregated.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_redactions(self): - """Test that annotations get correctly aggregated after a redaction.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Now lets redact one of the 'a' reactions - channel = self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), - access_token=self.user_token, - content={}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - - def test_aggregation_must_be_annotation(self): - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" - % (self.room, self.parent_id, RelationTypes.REPLACE), - access_token=self.user_token, - ) - self.assertEquals(400, channel.code, channel.json_body) - - def test_aggregation_get_event(self): - """Test that annotations and references get correctly bundled when - getting the parent event. - """ - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - reply_2 = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - }, - ) - - def test_edit(self): - """Test that a simple edit works.""" - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals(channel.json_body["content"], new_body) - - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_multi_edit(self): - """Test that multiple edits, including attempts by people who - shouldn't be allowed, are correctly handled. - """ - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message.WRONG_TYPE", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals(channel.json_body["content"], new_body) - - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_edit_reply(self): - """Test that editing a reply works.""" - - # Create a reply to edit. - channel = self._send_relation( - RelationTypes.REFERENCE, - "m.room.message", - content={"msgtype": "m.text", "body": "A reply!"}, - ) - self.assertEquals(200, channel.code, channel.json_body) - reply = channel.json_body["event_id"] - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=reply, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, reply), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # We expect to see the new body in the dict, as well as the reference - # metadata sill intact. - self.assertDictContainsSubset(new_body, channel.json_body["content"]) - self.assertDictContainsSubset( - { - "m.relates_to": { - "event_id": self.parent_id, - "key": None, - "rel_type": "m.reference", - } - }, - channel.json_body["content"], - ) - - # We expect that the edit relation appears in the unsigned relations - # section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_relations_redaction_redacts_edits(self): - """Test that edits of an event are redacted when the original event - is redacted. - """ - # Send a new event - res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - parent_id=original_event_id, - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check the relation is returned - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEquals(len(channel.json_body["chunk"]), 1) - - # Redact the original event - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), - access_token=self.user_token, - content="{}", - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) - - def test_aggregations_redaction_prevents_access_to_aggregations(self): - """Test that annotations of an event are redacted when the original event - is redacted. - """ - # Send a new event - res = self.helper.send(self.room, body="Hello!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Redact the original - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % ( - self.room, - original_event_id, - "test_aggregations_redaction_prevents_access_to_aggregations", - ), - access_token=self.user_token, - content="{}", - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check that aggregations returns zero - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) - - def _send_relation( - self, - relation_type, - event_type, - key=None, - content: Optional[dict] = None, - access_token=None, - parent_id=None, - ): - """Helper function to send a relation pointing at `self.parent_id` - - Args: - relation_type (str): One of `RelationTypes` - event_type (str): The type of the event to create - parent_id (str): The event_id this relation relates to. If None, then self.parent_id - key (str|None): The aggregation key used for m.annotation relation - type. - content(dict|None): The content of the created event. - access_token (str|None): The access token used to send the relation, - defaults to `self.user_token` - - Returns: - FakeChannel - """ - if not access_token: - access_token = self.user_token - - query = "" - if key: - query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) - - original_id = parent_id if parent_id else self.parent_id - - channel = self.make_request( - "POST", - "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" - % (self.room, original_id, relation_type, event_type, query), - json.dumps(content or {}).encode("utf-8"), - access_token=access_token, - ) - return channel - - def _create_user(self, localpart): - user_id = self.register_user(localpart, "abc123") - access_token = self.login(localpart, "abc123") - - return user_id, access_token diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py deleted file mode 100644 index ee6b0b9ebf..0000000000 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2021 Callum Brown -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import synapse.rest.admin -from synapse.rest.client import login, report_event, room - -from tests import unittest - - -class ReportEventTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - report_event.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - self.other_user = self.register_user("user", "pass") - self.other_user_tok = self.login("user", "pass") - - self.room_id = self.helper.create_room_as( - self.other_user, tok=self.other_user_tok, is_public=True - ) - self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) - resp = self.helper.send(self.room_id, tok=self.admin_user_tok) - self.event_id = resp["event_id"] - self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" - - def test_reason_str_and_score_int(self): - data = {"reason": "this makes me sad", "score": -100} - self._assert_status(200, data) - - def test_no_reason(self): - data = {"score": 0} - self._assert_status(200, data) - - def test_no_score(self): - data = {"reason": "this makes me sad"} - self._assert_status(200, data) - - def test_no_reason_and_no_score(self): - data = {} - self._assert_status(200, data) - - def test_reason_int_and_score_str(self): - data = {"reason": 10, "score": "string"} - self._assert_status(400, data) - - def test_reason_zero_and_score_blank(self): - data = {"reason": 0, "score": ""} - self._assert_status(400, data) - - def test_reason_and_score_null(self): - data = {"reason": None, "score": None} - self._assert_status(400, data) - - def _assert_status(self, response_status, data): - channel = self.make_request( - "POST", - self.report_path, - json.dumps(data), - access_token=self.other_user_tok, - ) - self.assertEqual( - response_status, int(channel.result["code"]), msg=channel.result["body"] - ) diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/v2_alpha/test_sendtodevice.py deleted file mode 100644 index 6db7062a8e..0000000000 --- a/tests/rest/client/v2_alpha/test_sendtodevice.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest import admin -from synapse.rest.client import login, sendtodevice, sync - -from tests.unittest import HomeserverTestCase, override_config - - -class SendToDeviceTestCase(HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - sendtodevice.register_servlets, - sync.register_servlets, - ] - - def test_user_to_user(self): - """A to-device message from one user to another should get delivered""" - - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send the message - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # check it appears - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - expected_result = { - "events": [ - { - "sender": user1, - "type": "m.test", - "content": test_msg, - } - ] - } - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should re-appear if we do another sync - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should *not* appear if we do an incremental sync - sync_token = channel.json_body["next_batch"] - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self): - """m.room_key_request has special-casing; test from local user""" - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send three messages - for i in range(3): - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", - content={"messages": {user2: {"d2": {"idx": i}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # now sync: we should get two of the three - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.room_key_request/3", - content={"messages": {user2: {"d2": {"idx": 3}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # ... which should arrive - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, - ) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self): - """m.room_key_request has special-casing; test from remote user""" - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - federation_registry = self.hs.get_federation_registry() - - # send three messages - for i in range(3): - self.get_success( - federation_registry.on_edu( - "m.direct_to_device", - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": i}}}, - "message_id": f"{i}", - }, - ) - ) - - # now sync: we should get two of the three - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": i}, - }, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - self.get_success( - federation_registry.on_edu( - "m.direct_to_device", - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": 3}}}, - "message_id": "3", - }, - ) - ) - - # ... which should arrive - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": 3}, - }, - ) diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py deleted file mode 100644 index 283eccd53f..0000000000 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import synapse.rest.admin -from synapse.rest.client import login, room, shared_rooms - -from tests import unittest -from tests.server import FakeChannel - - -class UserSharedRoomsTest(unittest.HomeserverTestCase): - """ - Tests the UserSharedRoomsServlet. - """ - - servlets = [ - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - shared_rooms.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.handler = hs.get_user_directory_handler() - - def _get_shared_rooms(self, token, other_user) -> FakeChannel: - return self.make_request( - "GET", - "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" - % other_user, - access_token=token, - ) - - def test_shared_room_list_public(self): - """ - A room should show up in the shared list of rooms between two users - if it is public. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - - def test_shared_room_list_private(self): - """ - A room should show up in the shared list of rooms between two users - if it is private. - """ - self._check_shared_rooms_with( - room_one_is_public=False, room_two_is_public=False - ) - - def test_shared_room_list_mixed(self): - """ - The shared room list between two users should contain both public and private - rooms. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) - - def _check_shared_rooms_with( - self, room_one_is_public: bool, room_two_is_public: bool - ): - """Checks that shared public or private rooms between two users appear in - their shared room lists - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - # Create a room. user1 invites user2, who joins - room_id_one = self.helper.create_room_as( - u1, is_public=room_one_is_public, tok=u1_token - ) - self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_one, user=u2, tok=u2_token) - - # Check shared rooms from user1's perspective. - # We should see the one room in common - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room_id_one) - - # Create another room and invite user2 to it - room_id_two = self.helper.create_room_as( - u1, is_public=room_two_is_public, tok=u1_token - ) - self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_two, user=u2, tok=u2_token) - - # Check shared rooms again. We should now see both rooms. - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 2) - for room_id_id in channel.json_body["joined"]: - self.assertIn(room_id_id, [room_id_one, room_id_two]) - - def test_shared_room_list_after_leave(self): - """ - A room should no longer be considered shared if the other - user has left it. - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - # Assert user directory is not empty - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) - - self.helper.leave(room, user=u1, tok=u1_token) - - # Check user1's view of shared rooms with user2 - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) - - # Check user2's view of shared rooms with user1 - channel = self._get_shared_rooms(u2_token, u1) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py deleted file mode 100644 index 95be369d4b..0000000000 --- a/tests/rest/client/v2_alpha/test_sync.py +++ /dev/null @@ -1,686 +0,0 @@ -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json - -import synapse.rest.admin -from synapse.api.constants import ( - EventContentFields, - EventTypes, - ReadReceiptEventFields, - RelationTypes, -) -from synapse.rest.client import knock, login, read_marker, receipts, room, sync - -from tests import unittest -from tests.federation.transport.test_knocking import ( - KnockingStrippedStateEventHelperMixin, -) -from tests.server import TimedOutException -from tests.unittest import override_config - - -class FilterTestCase(unittest.HomeserverTestCase): - - user_id = "@apple:test" - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_sync_argless(self): - channel = self.make_request("GET", "/sync") - - self.assertEqual(channel.code, 200) - self.assertIn("next_batch", channel.json_body) - - -class SyncFilterTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_sync_filter_labels(self): - """Test that we can filter by a label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_sync_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] - ) - - def test_sync_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_sync_filter_labels(self, sync_filter): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - - room_id = self.helper.create_room_as(user_id, tok=tok) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, - ) - - channel = self.make_request( - "GET", "/sync?filter=%s" % sync_filter, access_token=tok - ) - self.assertEqual(channel.code, 200, channel.result) - - return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] - - -class SyncTypingTests(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - user_id = True - hijack_auth = False - - def test_sync_backwards_typing(self): - """ - If the typing serial goes backwards and the typing handler is then reset - (such as when the master restarts and sets the typing serial to 0), we - do not incorrectly return typing information that had a serial greater - than the now-reset serial. - """ - typing_url = "/rooms/%s/typing/%s?access_token=%s" - sync_url = "/sync?timeout=3000000&access_token=%s&since=%s" - - # Register the user who gets notified - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - - # Register the user who sends the message - other_user_id = self.register_user("otheruser", "pass") - other_access_token = self.login("otheruser", "pass") - - # Create a room - room = self.helper.create_room_as(user_id, tok=access_token) - - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) - - # The other user joins - self.helper.join(room=room, user=other_user_id, tok=other_access_token) - - # The other user sends some messages - self.helper.send(room, body="Hi!", tok=other_access_token) - self.helper.send(room, body="There!", tok=other_access_token) - - # Start typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Stop typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": false}', - ) - self.assertEquals(200, channel.code) - - # Start typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - # Should return immediately - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Reset typing serial back to 0, as if the master had. - typing = self.hs.get_typing_handler() - typing._latest_room_serial = 0 - - # Since it checks the state token, we need some state to update to - # invalidate the stream token. - self.helper.send(room, body="There!", tok=other_access_token) - - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # This should time out! But it does not, because our stream token is - # ahead, and therefore it's saying the typing (that we've actually - # already seen) is new, since it's got a token above our new, now-reset - # stream token. - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Clear the typing information, so that it doesn't think everything is - # in the future. - typing._reset() - - # Now it SHOULD fail as it never completes! - with self.assertRaises(TimedOutException): - self.make_request("GET", sync_url % (access_token, next_batch)) - - -class SyncKnockTestCase( - unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin -): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - knock.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to create the room to knock on). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll knock on. - self.room_id = self.helper.create_room_as( - self.user_id, - is_public=False, - room_version="7", - tok=self.tok, - ) - - # Register the second user (used to knock on the room). - self.knocker = self.register_user("knocker", "monkey") - self.knocker_tok = self.login("knocker", "monkey") - - # Perform an initial sync for the knocking user. - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - # Set up some room state to test with. - self.expected_room_state = self.send_example_state_events_to_room( - hs, self.room_id, self.user_id - ) - - @override_config({"experimental_features": {"msc2403_enabled": True}}) - def test_knock_room_state(self): - """Tests that /sync returns state from a room after knocking on it.""" - # Knock on a room - channel = self.make_request( - "POST", - "/_matrix/client/r0/knock/%s" % (self.room_id,), - b"{}", - self.knocker_tok, - ) - self.assertEquals(200, channel.code, channel.result) - - # We expect to see the knock event in the stripped room state later - self.expected_room_state[EventTypes.Member] = { - "content": {"membership": "knock", "displayname": "knocker"}, - "state_key": "@knocker:test", - } - - # Check that /sync includes stripped state from the room - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.knocker_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Extract the stripped room state events from /sync - knock_entry = channel.json_body["rooms"]["knock"] - room_state_events = knock_entry[self.room_id]["knock_state"]["events"] - - # Validate that the knock membership event came last - self.assertEqual(room_state_events[-1]["type"], EventTypes.Member) - - # Validate the stripped room state events - self.check_knock_room_state_against_room_state( - room_state_events, self.expected_room_state - ) - - -class ReadReceiptsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - receipts.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Join the second user - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self): - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt to tell the server the first user's message was read - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that the first user can't see the other user's hidden read receipt - self.assertEqual(self._get_read_receipt(), None) - - def test_read_receipt_with_empty_body(self): - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt for this message with an empty body - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - def _get_read_receipt(self): - """Syncs and returns the read receipt.""" - - # Checks if event is a read receipt - def is_read_receipt(event): - return event["type"] == "m.receipt" - - # Sync - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - # Return the read receipt - ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ - "ephemeral" - ]["events"] - return next(filter(is_read_receipt, ephemeral_events), None) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - receipts.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. - self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that hidden read receipts don't break unread counts - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. - self.helper.send_state( - self.room_id, - "m.room.name", - {"name": "my super room"}, - tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, - "m.room.topic", - {"topic": "welcome!!!"}, - tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, - "org.matrix.custom_type", - {"body": "hello"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: int): - """Syncs and compares the unread count with the expected value.""" - - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], - expected_count, - room_entry, - ) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - -class SyncCacheTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_noop_sync_does_not_tightloop(self): - """If the sync times out, we shouldn't cache the result - - Essentially a regression test for #8518. - """ - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # we should immediately get an initial sync response - channel = self.make_request("GET", "/sync", access_token=self.tok) - self.assertEqual(channel.code, 200, channel.json_body) - - # now, make an incremental sync request, with a timeout - next_batch = channel.json_body["next_batch"] - channel = self.make_request( - "GET", - f"/sync?since={next_batch}&timeout=10000", - access_token=self.tok, - await_result=False, - ) - # that should block for 10 seconds - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=9900) - channel.await_result(timeout_ms=200) - self.assertEqual(channel.code, 200, channel.json_body) - - # we expect the next_batch in the result to be the same as before - self.assertEqual(channel.json_body["next_batch"], next_batch) - - # another incremental sync should also block. - channel = self.make_request( - "GET", - f"/sync?since={next_batch}&timeout=10000", - access_token=self.tok, - await_result=False, - ) - # that should block for 10 seconds - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=9900) - channel.await_result(timeout_ms=200) - self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py deleted file mode 100644 index 72f976d8e2..0000000000 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.rest import admin -from synapse.rest.client import login, room, room_upgrade_rest_servlet - -from tests import unittest -from tests.server import FakeChannel - - -class UpgradeRoomTest(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - room_upgrade_rest_servlet.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.handler = hs.get_user_directory_handler() - - self.creator = self.register_user("creator", "pass") - self.creator_token = self.login(self.creator, "pass") - - self.other = self.register_user("user", "pass") - self.other_token = self.login(self.other, "pass") - - self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) - self.helper.join(self.room_id, self.other, tok=self.other_token) - - def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: - # We never want a cached response. - self.reactor.advance(5 * 60 + 1) - - return self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, - # This will upgrade a room to the same version, but that's fine. - content={"new_version": DEFAULT_ROOM_VERSION}, - access_token=token or self.creator_token, - ) - - def test_upgrade(self): - """ - Upgrading a room should work fine. - """ - channel = self._upgrade_room() - self.assertEquals(200, channel.code, channel.result) - self.assertIn("replacement_room", channel.json_body) - - def test_not_in_room(self): - """ - Upgrading a room should work fine. - """ - # THe user isn't in the room. - roomless = self.register_user("roomless", "pass") - roomless_token = self.login(roomless, "pass") - - channel = self._upgrade_room(roomless_token) - self.assertEquals(403, channel.code, channel.result) - - def test_power_levels(self): - """ - Another user can upgrade the room if their power level is increased. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["users"][self.other] = 100 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - def test_power_levels_user_default(self): - """ - Another user can upgrade the room if the default power level for users is increased. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["users_default"] = 100 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - def test_power_levels_tombstone(self): - """ - Another user can upgrade the room if they can send the tombstone event. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["events"]["m.room.tombstone"] = 0 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/unittest.py b/tests/unittest.py index 3eec9c4d5b..f2c90cc47b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase): reactor=self.reactor, ) - from tests.rest.client.v1.utils import RestHelper + from tests.rest.client.utils import RestHelper self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) -- cgit 1.4.1