From 1a9f531c79f8a043db6f151c5393f23aa5675800 Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Tue, 17 Aug 2021 14:22:45 +0100 Subject: Port the PresenceRouter module interface to the new generic interface (#10524) Port the PresenceRouter module interface to the new generic interface introduced in v1.37.0 --- tests/events/test_presence_router.py | 109 +++++++++++++++++++++++++++++++++-- 1 file changed, 104 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 6b87f571b8..3b3866bff8 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import attr from synapse.api.constants import EduTypes -from synapse.events.presence_router import PresenceRouter +from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi @@ -34,7 +34,7 @@ class PresenceRouterTestConfig: users_who_should_receive_all_presence = attr.ib(type=List[str], default=[]) -class PresenceRouterTestModule: +class LegacyPresenceRouterTestModule: def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi): self._config = config self._module_api = module_api @@ -77,6 +77,53 @@ class PresenceRouterTestModule: return config +class PresenceRouterTestModule: + def __init__(self, config: PresenceRouterTestConfig, api: ModuleApi): + self._config = config + self._module_api = api + api.register_presence_router_callbacks( + get_users_for_states=self.get_users_for_states, + get_interested_users=self.get_interested_users, + ) + + async def get_users_for_states( + self, state_updates: Iterable[UserPresenceState] + ) -> Dict[str, Set[UserPresenceState]]: + users_to_state = { + user_id: set(state_updates) + for user_id in self._config.users_who_should_receive_all_presence + } + return users_to_state + + async def get_interested_users( + self, user_id: str + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + if user_id in self._config.users_who_should_receive_all_presence: + return PresenceRouter.ALL_USERS + + return set() + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterTestConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterTestConfig() + + config.users_who_should_receive_all_presence = config_dict.get( + "users_who_should_receive_all_presence" + ) + + return config + + class PresenceRouterTestCase(FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -86,9 +133,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ] def make_homeserver(self, reactor, clock): - return self.setup_test_homeserver( + hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), ) + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + + load_legacy_presence_router(hs) + + return hs def prepare(self, reactor, clock, homeserver): self.sync_handler = self.hs.get_sync_handler() @@ -98,7 +153,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): { "presence": { "presence_router": { - "module": __name__ + ".PresenceRouterTestModule", + "module": __name__ + ".LegacyPresenceRouterTestModule", "config": { "users_who_should_receive_all_presence": [ "@presence_gobbler:test", @@ -109,7 +164,28 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): "send_federation": True, } ) + def test_receiving_all_presence_legacy(self): + self.receiving_all_presence_test_body() + + @override_config( + { + "modules": [ + { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler:test", + ] + }, + }, + ], + "send_federation": True, + } + ) def test_receiving_all_presence(self): + self.receiving_all_presence_test_body() + + def receiving_all_presence_test_body(self): """Test that a user that does not share a room with another other can receive presence for them, due to presence routing. """ @@ -203,7 +279,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): { "presence": { "presence_router": { - "module": __name__ + ".PresenceRouterTestModule", + "module": __name__ + ".LegacyPresenceRouterTestModule", "config": { "users_who_should_receive_all_presence": [ "@presence_gobbler1:test", @@ -216,7 +292,30 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): "send_federation": True, } ) + def test_send_local_online_presence_to_with_module_legacy(self): + self.send_local_online_presence_to_with_module_test_body() + + @override_config( + { + "modules": [ + { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler1:test", + "@presence_gobbler2:test", + "@far_away_person:island", + ] + }, + }, + ], + "send_federation": True, + } + ) def test_send_local_online_presence_to_with_module(self): + self.send_local_online_presence_to_with_module_test_body() + + def send_local_online_presence_to_with_module_test_body(self): """Tests that send_local_presence_to_users sends local online presence to a set of specified local and remote users, with a custom PresenceRouter module enabled. """ -- cgit 1.5.1 From 430241a1e9fd4fbb82d83958b61bbd66c9ba3505 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 Aug 2021 22:19:13 +0200 Subject: Remove deprecated Shutdown Room and Purge Room Admin API (#8830) --- changelog.d/8830.removal | 1 + docs/SUMMARY.md | 2 - docs/admin_api/purge_room.md | 21 ---- docs/admin_api/shutdown_room.md | 102 ------------------- docs/upgrade.md | 13 +++ synapse/rest/admin/__init__.py | 4 - synapse/rest/admin/purge_room_servlet.py | 58 ----------- synapse/rest/admin/rooms.py | 35 ------- tests/rest/admin/test_room.py | 162 ------------------------------- 9 files changed, 14 insertions(+), 384 deletions(-) create mode 100644 changelog.d/8830.removal delete mode 100644 docs/admin_api/purge_room.md delete mode 100644 docs/admin_api/shutdown_room.md delete mode 100644 synapse/rest/admin/purge_room_servlet.py (limited to 'tests') diff --git a/changelog.d/8830.removal b/changelog.d/8830.removal new file mode 100644 index 0000000000..b3a93a9af2 --- /dev/null +++ b/changelog.d/8830.removal @@ -0,0 +1 @@ +Remove deprecated Shutdown Room and Purge Room Admin API. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 3d320a1c43..0714f151fd 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -51,12 +51,10 @@ - [Event Reports](admin_api/event_reports.md) - [Media](admin_api/media_admin_api.md) - [Purge History](admin_api/purge_history_api.md) - - [Purge Rooms](admin_api/purge_room.md) - [Register Users](admin_api/register_api.md) - [Manipulate Room Membership](admin_api/room_membership.md) - [Rooms](admin_api/rooms.md) - [Server Notices](admin_api/server_notices.md) - - [Shutdown Room](admin_api/shutdown_room.md) - [Statistics](admin_api/statistics.md) - [Users](admin_api/user_admin_api.md) - [Server Version](admin_api/version_api.md) diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md deleted file mode 100644 index 54fea2db6d..0000000000 --- a/docs/admin_api/purge_room.md +++ /dev/null @@ -1,21 +0,0 @@ -Deprecated: Purge room API -========================== - -**The old Purge room API is deprecated and will be removed in a future release. -See the new [Delete Room API](rooms.md#delete-room-api) for more details.** - -This API will remove all trace of a room from your database. - -All local users must have left the room before it can be removed. - -The API is: - -``` -POST /_synapse/admin/v1/purge_room - -{ - "room_id": "!room:id" -} -``` - -You must authenticate using the access token of an admin user. diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md deleted file mode 100644 index 856a629487..0000000000 --- a/docs/admin_api/shutdown_room.md +++ /dev/null @@ -1,102 +0,0 @@ -# Deprecated: Shutdown room API - -**The old Shutdown room API is deprecated and will be removed in a future release. -See the new [Delete Room API](rooms.md#delete-room-api) for more details.** - -Shuts down a room, preventing new joins and moves local users and room aliases automatically -to a new room. The new room will be created with the user specified by the -`new_room_user_id` parameter as room administrator and will contain a message -explaining what happened. Users invited to the new room will have power level --10 by default, and thus be unable to speak. The old room's power levels will be changed to -disallow any further invites or joins. - -The local server will only have the power to move local user and room aliases to -the new room. Users on other servers will be unaffected. - -## API - -You will need to authenticate with an access token for an admin user. - -### URL - -`POST /_synapse/admin/v1/shutdown_room/{room_id}` - -### URL Parameters - -* `room_id` - The ID of the room (e.g `!someroom:example.com`) - -### JSON Body Parameters - -* `new_room_user_id` - Required. A string representing the user ID of the user that will admin - the new room that all users in the old room will be moved to. -* `room_name` - Optional. A string representing the name of the room that new users will be - invited to. -* `message` - Optional. A string containing the first message that will be sent as - `new_room_user_id` in the new room. Ideally this will clearly convey why the - original room was shut down. - -If not specified, the default value of `room_name` is "Content Violation -Notification". The default value of `message` is "Sharing illegal content on -othis server is not permitted and rooms in violation will be blocked." - -### Response Parameters - -* `kicked_users` - An integer number representing the number of users that - were kicked. -* `failed_to_kick_users` - An integer number representing the number of users - that were not kicked. -* `local_aliases` - An array of strings representing the local aliases that were migrated from - the old room to the new. -* `new_room_id` - A string representing the room ID of the new room. - -## Example - -Request: - -``` -POST /_synapse/admin/v1/shutdown_room/!somebadroom%3Aexample.com - -{ - "new_room_user_id": "@someuser:example.com", - "room_name": "Content Violation Notification", - "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service." -} -``` - -Response: - -``` -{ - "kicked_users": 5, - "failed_to_kick_users": 0, - "local_aliases": ["#badroom:example.com", "#evilsaloon:example.com], - "new_room_id": "!newroomid:example.com", -}, -``` - -## Undoing room shutdowns - -*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, -the structure can and does change without notice. - -First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it -never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible -to recover at all: - -* If the room was invite-only, your users will need to be re-invited. -* If the room no longer has any members at all, it'll be impossible to rejoin. -* The first user to rejoin will have to do so via an alias on a different server. - -With all that being said, if you still want to try and recover the room: - -1. For safety reasons, shut down Synapse. -2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` - * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. - * The room ID is the same one supplied to the shutdown room API, not the Content Violation room. -3. Restart Synapse. - -You will have to manually handle, if you so choose, the following: - -* Aliases that would have been redirected to the Content Violation room. -* Users that would have been booted from the room (and will have been force-joined to the Content Violation room). -* Removal of the Content Violation room if desired. diff --git a/docs/upgrade.md b/docs/upgrade.md index 8831c9d6cf..3ac2387e2a 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,19 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.xx.0 + +## Removal of old Room Admin API + +The following admin APIs were deprecated in [Synapse 1.25](https://github.com/matrix-org/synapse/blob/v1.25.0/CHANGES.md#removal-warning) +(released on 2021-01-13) and have now been removed: + +- `POST /_synapse/admin/v1/purge_room` +- `POST /_synapse/admin/v1/shutdown_room/` + +Any scripts still using the above APIs should be converted to use the +[Delete Room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api). + # Upgrading to v1.xx.0 diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 8a91068092..2acaea3003 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -36,7 +36,6 @@ from synapse.rest.admin.event_reports import ( ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo -from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, ForwardExtremitiesRestServlet, @@ -47,7 +46,6 @@ from synapse.rest.admin.rooms import ( RoomMembersRestServlet, RoomRestServlet, RoomStateRestServlet, - ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet @@ -221,7 +219,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) @@ -255,7 +252,6 @@ def register_servlets_for_client_rest_resource( PurgeHistoryRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) - ShutdownRoomRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py deleted file mode 100644 index 2365ff7a0f..0000000000 --- a/synapse/rest/admin/purge_room_servlet.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Tuple - -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.http.site import SynapseRequest -from synapse.rest.admin import assert_requester_is_admin -from synapse.rest.admin._base import admin_patterns -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class PurgeRoomServlet(RestServlet): - """Servlet which will remove all trace of a room from the database - - POST /_synapse/admin/v1/purge_room - { - "room_id": "!room:id" - } - - returns: - - {} - """ - - PATTERNS = admin_patterns("/purge_room$") - - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.auth = hs.get_auth() - self.pagination_handler = hs.get_pagination_handler() - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ("room_id",)) - - await self.pagination_handler.purge_room(body["room_id"]) - - return 200, {} diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 975c28b225..ad83d4b54c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -46,41 +46,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ShutdownRoomRestServlet(RestServlet): - """Shuts down a room by removing all local users from the room and blocking - all future invites and joins to the room. Any local aliases will be repointed - to a new room created by `new_room_user_id` and kicked users will be auto - joined to the new room. - """ - - PATTERNS = admin_patterns("/shutdown_room/(?P[^/]+)") - - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.auth = hs.get_auth() - self.room_shutdown_handler = hs.get_room_shutdown_handler() - - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) - - content = parse_json_object_from_request(request) - assert_params_in_dict(content, ["new_room_user_id"]) - - ret = await self.room_shutdown_handler.shutdown_room( - room_id=room_id, - new_room_user_id=content["new_room_user_id"], - new_room_name=content.get("room_name"), - message=content.get("message"), - requester_user_id=requester.user.to_string(), - block=True, - ) - - return (200, ret) - - class DeleteRoomRestServlet(RestServlet): """Delete a room from server. diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index c9d4731017..40e032df7f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -29,123 +29,6 @@ from tests import unittest """Tests admin REST events for /rooms paths.""" -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/shutdown_room/" + room_id - channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/shutdown_room/" + room_id - channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room.""" - - url = "rooms/%s/initialSync" % (room_id,) - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - @parameterized_class( ("method", "url_template"), [ @@ -557,51 +440,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API.""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. - for table in PURGE_TABLES: - count = self.get_success( - self.store.db_pool.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg=f"Rows not purged in {table}") - - class RoomTestCase(unittest.HomeserverTestCase): """Test /room admin API.""" -- cgit 1.5.1 From bec01c075829730cf467572e2fcf93e15372b0e9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 18 Aug 2021 09:22:07 -0400 Subject: Convert room member storage tuples to attrs. (#10629) Instead of using namedtuples. This helps with asserting type hints and code completion. --- changelog.d/10629.misc | 1 + synapse/handlers/initial_sync.py | 2 +- synapse/handlers/sync.py | 18 +++++----- synapse/storage/databases/main/roommember.py | 8 +++-- synapse/storage/databases/main/user_directory.py | 2 +- synapse/storage/roommember.py | 43 ++++++++++++++++-------- tests/replication/slave/storage/test_events.py | 9 +++-- 7 files changed, 54 insertions(+), 29 deletions(-) create mode 100644 changelog.d/10629.misc (limited to 'tests') diff --git a/changelog.d/10629.misc b/changelog.d/10629.misc new file mode 100644 index 0000000000..cca1eb6c57 --- /dev/null +++ b/changelog.d/10629.misc @@ -0,0 +1 @@ +Convert room member storage tuples to `attrs` classes. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index e1c544a3c9..4e8f7f1d85 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler): limit = 10 async def handle_room(event: RoomsForUser): - d = { + d: JsonDict = { "room_id": event.room_id, "membership": event.membership, "visibility": ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index eba915819e..b7b299961f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -701,7 +701,7 @@ class SyncHandler: name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) - summary = {} + summary: JsonDict = {} empty_ms = MemberSummary([], 0) # TODO: only send these when they change. @@ -2076,21 +2076,23 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. - for room_id, event_pos in joined_rooms: - if not event_pos.persisted_after(room_key): - joined_room_ids.add(room_id) + for joined_room in joined_rooms: + if not joined_room.event_pos.persisted_after(room_key): + joined_room_ids.add(joined_room.room_id) continue - logger.info("User joined room after current token: %s", room_id) + logger.info("User joined room after current token: %s", joined_room.room_id) extrems = ( await self.store.get_forward_extremities_for_room_at_stream_ordering( - room_id, event_pos.stream + joined_room.room_id, joined_room.event_pos.stream ) ) - users_in_room = await self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room( + joined_room.room_id, extrems + ) if user_id in users_in_room: - joined_room_ids.add(room_id) + joined_room_ids.add(joined_room.room_id) return frozenset(joined_room_ids) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e8157ba3d4..c2f6b9d63d 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached() - async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: + async def get_invited_rooms_for_local_user( + self, user_id: str + ) -> List[RoomsForUser]: """Get all the rooms the *local* user is invited to. Args: @@ -522,7 +524,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): _get_users_server_still_shares_room_with_txn, ) - async def get_rooms_for_user(self, user_id: str, on_invalidate=None): + async def get_rooms_for_user( + self, user_id: str, on_invalidate=None + ) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 9d28d69ac7..65dde67ae9 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return False async def update_profile_in_user_dir( - self, user_id: str, display_name: str, avatar_url: str + self, user_id: str, display_name: Optional[str], avatar_url: Optional[str] ) -> None: """ Update or add a user's profile in the user directory. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index c34fbf21bc..0ff66debdf 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -14,25 +14,40 @@ # limitations under the License. import logging -from collections import namedtuple +from typing import List, Optional, Tuple + +import attr + +from synapse.types import PersistedEventPosition logger = logging.getLogger(__name__) -RoomsForUser = namedtuple( - "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering") -) +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class RoomsForUser: + room_id: str + sender: str + membership: str + event_id: str + stream_ordering: int + + +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class GetRoomsForUserWithStreamOrdering: + room_id: str + event_pos: PersistedEventPosition -GetRoomsForUserWithStreamOrdering = namedtuple( - "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") -) +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class ProfileInfo: + avatar_url: Optional[str] + display_name: Optional[str] -# We store this using a namedtuple so that we save about 3x space over using a -# dict. -ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) -# "members" points to a truncated list of (user_id, event_id) tuples for users of -# a given membership type, suitable for use in calculating heroes for a room. -# "count" points to the total numberr of users of a given membership type. -MemberSummary = namedtuple("MemberSummary", ("members", "count")) +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class MemberSummary: + # A truncated list of (user_id, event_id) tuples for users of a given + # membership type, suitable for use in calculating heroes for a room. + members: List[Tuple[str, str]] + # The total number of users of a given membership type. + count: int diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index db80a0bdbd..8d1b0606c4 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.roommember import RoomsForUser +from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition from tests.server import FakeTransport @@ -216,7 +216,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_rooms_for_user_with_stream_ordering", (USER_ID_2,), - {(ROOM_ID, expected_pos)}, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, ) def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): @@ -305,7 +305,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): expected_pos = PersistedEventPosition( "master", j2.internal_metadata.stream_ordering ) - self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)}) + self.assertEqual( + joined_rooms, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, + ) event_id = 0 -- cgit 1.5.1 From 220f901229a506a82aedc51c5923768bf935ea4f Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 19 Aug 2021 11:25:05 +0200 Subject: Remove not needed database updates in modify user admin API (#10627) --- changelog.d/10627.misc | 1 + docs/admin_api/user_admin_api.md | 8 +++- synapse/rest/admin/users.py | 55 ++++++++++++++--------- synapse/storage/databases/main/registration.py | 25 ++++++++--- tests/rest/admin/test_user.py | 62 ++++++++++++++++++++++++-- 5 files changed, 118 insertions(+), 33 deletions(-) create mode 100644 changelog.d/10627.misc (limited to 'tests') diff --git a/changelog.d/10627.misc b/changelog.d/10627.misc new file mode 100644 index 0000000000..e6d314976e --- /dev/null +++ b/changelog.d/10627.misc @@ -0,0 +1 @@ +Remove not needed database updates in modify user admin API. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 6a9335d6ec..60dc913915 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -21,11 +21,15 @@ It returns a JSON body like the following: "threepids": [ { "medium": "email", - "address": "" + "address": "", + "added_at": 1586458409743, + "validated_at": 1586458409743 }, { "medium": "email", - "address": "" + "address": "", + "added_at": 1586458409743, + "validated_at": 1586458409743 } ], "avatar_url": "", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 3c8a0c6883..c1a1ba645e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet): if not isinstance(deactivate, bool): raise SynapseError(400, "'deactivated' parameter is not of type boolean") - # convert into List[Tuple[str, str]] + # convert List[Dict[str, str]] into Set[Tuple[str, str]] if external_ids is not None: - new_external_ids = [] - for external_id in external_ids: - new_external_ids.append( - (external_id["auth_provider"], external_id["external_id"]) - ) + new_external_ids = { + (external_id["auth_provider"], external_id["external_id"]) + for external_id in external_ids + } + + # convert List[Dict[str, str]] into Set[Tuple[str, str]] + if threepids is not None: + new_threepids = { + (threepid["medium"], threepid["address"]) for threepid in threepids + } if user: # modify user if "displayname" in body: @@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet): ) if threepids is not None: - # remove old threepids from user - old_threepids = await self.store.user_get_threepids(user_id) - for threepid in old_threepids: + # get changed threepids (added and removed) + # convert List[Dict[str, Any]] into Set[Tuple[str, str]] + cur_threepids = { + (threepid["medium"], threepid["address"]) + for threepid in await self.store.user_get_threepids(user_id) + } + add_threepids = new_threepids - cur_threepids + del_threepids = cur_threepids - new_threepids + + # remove old threepids + for medium, address in del_threepids: try: await self.auth_handler.delete_threepid( - user_id, threepid["medium"], threepid["address"], None + user_id, medium, address, None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") - # add new threepids to user + # add new threepids current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for medium, address in add_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if external_ids is not None: # get changed external_ids (added and removed) - cur_external_ids = await self.store.get_external_ids_by_user(user_id) - add_external_ids = set(new_external_ids) - set(cur_external_ids) - del_external_ids = set(cur_external_ids) - set(new_external_ids) + cur_external_ids = set( + await self.store.get_external_ids_by_user(user_id) + ) + add_external_ids = new_external_ids - cur_external_ids + del_external_ids = cur_external_ids - new_external_ids # remove old external_ids for auth_provider, external_id in del_external_ids: @@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet): if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for medium, address in new_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if ( self.hs.config.email_enable_notifs @@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet): kind="email", app_id="m.email", app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], + device_display_name=address, + pushkey=address, lang=None, # We don't know a user's language here data={}, ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c67bea81c6..469dd53e0c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) return user_id - def get_user_id_by_threepid_txn(self, txn, medium, address): + def get_user_id_by_threepid_txn( + self, txn, medium: str, address: str + ) -> Optional[str]: """Returns user id from threepid Args: txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com Returns: - str|None: user id or None if no user id/threepid mapping exists + user id, or None if no user id/threepid mapping exists """ ret = self.db_pool.simple_select_one_txn( txn, @@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return ret["user_id"] return None - async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + async def user_add_threepid( + self, + user_id: str, + medium: str, + address: str, + validated_at: int, + added_at: int, + ) -> None: await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id): + async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "user_get_threepids", ) - async def user_delete_threepid(self, user_id, medium, address) -> None: + async def user_delete_threepid( + self, user_id: str, medium: str, address: str + ) -> None: await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ef77275238..ee204c404b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual( "external_id1", channel.json_body["external_ids"][0]["external_id"] ) self.assertEqual( "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] ) + self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test setting threepid for an other user. """ - # Delete old and add new threepid to user + # Add two threepids to user channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + {"medium": "email", "address": "bob2@bob.bob"}, + ], + }, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) + # result does not always have the same sort order, therefore it becomes sorted + sorted_result = sorted( + channel.json_body["threepids"], key=lambda k: k["address"] + ) + self.assertEqual("email", sorted_result[0]["medium"]) + self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) + self.assertEqual("email", sorted_result[1]["medium"]) + self.assertEqual("bob2@bob.bob", sorted_result[1]["address"]) + self._check_fields(channel.json_body) + + # Set a new and remove a threepid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob2@bob.bob"}, + {"medium": "email", "address": "bob3@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) + + # Remove threepids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self._check_fields(channel.json_body) def test_set_external_id(self): """ @@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( channel.json_body["external_ids"], [ -- cgit 1.5.1 From b5fef6054a3d5763a031440ad4c56fc97b12d2aa Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 19 Aug 2021 11:40:40 +0200 Subject: Support MSC3283: Expose `enable_set_displayname` in capabilities (#10452) --- changelog.d/10452.feature | 1 + synapse/config/experimental.py | 3 + synapse/rest/client/capabilities.py | 11 +++ tests/rest/client/v2_alpha/test_capabilities.py | 109 +++++++++++++++++++----- 4 files changed, 101 insertions(+), 23 deletions(-) create mode 100644 changelog.d/10452.feature (limited to 'tests') diff --git a/changelog.d/10452.feature b/changelog.d/10452.feature new file mode 100644 index 0000000000..f332b383e3 --- /dev/null +++ b/changelog.d/10452.feature @@ -0,0 +1 @@ +Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose enable_set_displayname in capabilities. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b918fb15b0..a85d8a5dc6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -39,5 +39,8 @@ class ExperimentalConfig(Config): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) + # MSC3283 (set displayname, avatar_url and change 3pid capabilities) + self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False) + # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 88e3aac797..093549512e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -61,6 +61,17 @@ class CapabilitiesRestServlet(RestServlet): "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES + if self.config.experimental.msc3283_enabled: + response["capabilities"]["org.matrix.msc3283.set_displayname"] = { + "enabled": self.config.enable_set_displayname + } + response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { + "enabled": self.config.enable_set_avatar_url + } + response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { + "enabled": self.config.enable_3pid_changes + } + return 200, response diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index ad83b3d2ff..ac31e5ceaf 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.url = b"/_matrix/client/r0/capabilities" hs = self.setup_test_homeserver() - self.store = hs.get_datastore() self.config = hs.config self.auth_handler = hs.get_auth_handler() return hs + def prepare(self, reactor, clock, hs): + self.localpart = "user" + self.password = "pass" + self.user = self.register_user(self.localpart, self.password) + def test_check_auth_required(self): channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) def test_get_room_version_capabilities(self): - self.register_user("user", "pass") - access_token = self.login("user", "pass") + access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] @@ -57,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): ) def test_get_change_password_capabilities_password_login(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) - access_token = self.login(user, password) + access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] @@ -70,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"localdb_enabled": False}}) def test_get_change_password_capabilities_localdb_disabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -87,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"enabled": False}}) def test_get_change_password_capabilities_password_disabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -102,13 +96,85 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) + def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): + """Test that per default msc3283 is disabled server returns `m.change_password`.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) + self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) + self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) + + @override_config({"experimental_features": {"msc3283_enabled": True}}) + def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): + """Test if msc3283 is enabled server returns capabilities.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + @override_config( + { + "enable_set_displayname": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_displayname_capabilities_displayname_disabled(self): + """Test if set displayname is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + + @override_config( + { + "enable_set_avatar_url": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + """Test if set avatar_url is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + + @override_config( + { + "enable_3pid_changes": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_change_3pid_capabilities_3pid_disabled(self): + """Test if change 3pid is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + def test_get_does_not_include_msc3244_fields_by_default(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -122,12 +188,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"experimental_features": {"msc3244_enabled": True}}) def test_get_does_include_msc3244_fields_when_enabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) -- cgit 1.5.1 From 000aa89be63c27092998eca03c97eaead21404cd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Aug 2021 11:12:55 -0400 Subject: Do not include rooms with an unknown room version in a sync response. (#10644) A user will still see this room if it is in a local cache, but it will not reappear if clearing the cache and reloading. --- changelog.d/10644.bugfix | 1 + mypy.ini | 1 + synapse/handlers/sync.py | 7 +- synapse/storage/databases/main/roommember.py | 8 +- synapse/storage/roommember.py | 1 + tests/handlers/test_sync.py | 137 +++++++++++++++++++++++-- tests/replication/slave/storage/test_events.py | 1 + 7 files changed, 145 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10644.bugfix (limited to 'tests') diff --git a/changelog.d/10644.bugfix b/changelog.d/10644.bugfix new file mode 100644 index 0000000000..d88a81fd82 --- /dev/null +++ b/changelog.d/10644.bugfix @@ -0,0 +1 @@ +Rooms with unsupported room versions are no longer returned via `/sync`. diff --git a/mypy.ini b/mypy.ini index 107f4de76c..90ade37b3f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,6 +90,7 @@ files = tests/test_utils, tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, + tests/handlers/test_sync.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b7b299961f..2203c45dcc 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1,5 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018, 2019 New Vector Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +30,7 @@ from prometheus_client import Counter from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span @@ -1843,6 +1843,9 @@ class SyncHandler: knocked = [] for event in room_list: + if event.room_version_id not in KNOWN_ROOM_VERSIONS: + continue + if event.membership == Membership.JOIN: room_entries.append( RoomSyncResultBuilder( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index c2f6b9d63d..c58a4b8690 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -386,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version FROM local_current_membership AS c INNER JOIN events AS e USING (room_id, event_id) + INNER JOIN rooms AS r USING (room_id) WHERE user_id = ? AND %s @@ -397,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)] + results = [RoomsForUser(*r) for r in txn] return results @@ -447,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns: Returns the rooms the user is in currently, along with the stream - ordering of the most recent join for that user and room. + ordering of the most recent join for that user and room, along with + the room version of the room. """ return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9fad67ce48..2500381b7b 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -30,6 +30,7 @@ class RoomsForUser: membership: str event_id: str stream_ordering: int + room_version_id: str @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 84f05f6c58..339c039914 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION +from synapse.api.room_versions import RoomVersions from synapse.handlers.sync import SyncConfig +from synapse.rest import admin +from synapse.rest.client import knock, login, room +from synapse.server import HomeServer from synapse.types import UserID, create_requester import tests.unittest @@ -24,8 +31,14 @@ import tests.utils class SyncTestCase(tests.unittest.HomeserverTestCase): """Tests Sync Handler.""" - def prepare(self, reactor, clock, hs): - self.hs = hs + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() @@ -68,12 +81,124 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_unknown_room_version(self): + """ + A room with an unknown room version should not break sync (and should be excluded). + """ + inviter = self.register_user("creator", "pass", admin=True) + inviter_tok = self.login("@creator:test", "pass") + + user = self.register_user("user", "pass") + tok = self.login("user", "pass") + + # Do an initial sync on a different device. + requester = create_requester(user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user, device_id="dev") + ) + ) + + # Create a room as the user. + joined_room = self.helper.create_room_as(user, tok=tok) + + # Invite the user to the room as someone else. + invite_room = self.helper.create_room_as(inviter, tok=inviter_tok) + self.helper.invite(invite_room, targ=user, tok=inviter_tok) + + knock_room = self.helper.create_room_as( + inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok + ) + self.helper.send_state( + knock_room, + EventTypes.JoinRules, + {"join_rule": JoinRules.KNOCK}, + tok=inviter_tok, + ) + channel = self.make_request( + "POST", + "/_matrix/client/r0/knock/%s" % (knock_room,), + b"{}", + tok, + ) + self.assertEquals(200, channel.code, channel.result) + + # The rooms should appear in the sync response. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user) + ) + ) + self.assertIn(joined_room, [r.room_id for r in result.joined]) + self.assertIn(invite_room, [r.room_id for r in result.invited]) + self.assertIn(knock_room, [r.room_id for r in result.knocked]) + + # Test a incremental sync (by providing a since_token). + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config(user, device_id="dev"), + since_token=initial_result.next_batch, + ) + ) + self.assertIn(joined_room, [r.room_id for r in result.joined]) + self.assertIn(invite_room, [r.room_id for r in result.invited]) + self.assertIn(knock_room, [r.room_id for r in result.knocked]) + + # Poke the database and update the room version to an unknown one. + for room_id in (joined_room, invite_room, knock_room): + self.get_success( + self.hs.get_datastores().main.db_pool.simple_update( + "rooms", + keyvalues={"room_id": room_id}, + updatevalues={"room_version": "unknown-room-version"}, + desc="updated-room-version", + ) + ) + + # Blow away caches (supported room versions can only change due to a restart). + self.get_success( + self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + ) + self.store._get_event_cache.clear() + + # The rooms should be excluded from the sync response. + # Get a new request key. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user) + ) + ) + self.assertNotIn(joined_room, [r.room_id for r in result.joined]) + self.assertNotIn(invite_room, [r.room_id for r in result.invited]) + self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) + + # The rooms should also not be in an incremental sync. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config(user, device_id="dev"), + since_token=initial_result.next_batch, + ) + ) + self.assertNotIn(joined_room, [r.room_id for r in result.joined]) + self.assertNotIn(invite_room, [r.room_id for r in result.invited]) + self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) + + +_request_key = 0 + -def generate_sync_config(user_id: str) -> SyncConfig: +def generate_sync_config( + user_id: str, device_id: Optional[str] = "device_id" +) -> SyncConfig: + """Generate a sync config (with a unique request key).""" + global _request_key + _request_key += 1 return SyncConfig( - user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), + user=UserID.from_string(user_id), filter_collection=DEFAULT_FILTER_COLLECTION, is_guest=False, - request_key="request_key", - device_id="device_id", + request_key=("request_key", _request_key), + device_id=device_id, ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 8d1b0606c4..b25a06b427 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -150,6 +150,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "invite", event.event_id, event.internal_metadata.stream_ordering, + RoomVersions.V1.identifier, ) ], ) -- cgit 1.5.1 From e81d62009ee091d69d56bca702ba0edd39becb1c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 19 Aug 2021 18:05:12 +0100 Subject: Split `on_receive_pdu` in half (#10640) Here we split on_receive_pdu into two functions (on_receive_pdu and process_pulled_event), rather than having both cases in the same method. There's a tiny bit of overlap, but not that much. --- changelog.d/10640.misc | 1 + synapse/federation/federation_server.py | 4 +- synapse/handlers/federation.py | 236 +++++++++++++++++++------------- tests/test_federation.py | 10 +- 4 files changed, 142 insertions(+), 109 deletions(-) create mode 100644 changelog.d/10640.misc (limited to 'tests') diff --git a/changelog.d/10640.misc b/changelog.d/10640.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10640.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index afd8f8580a..e1b58d40c5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1005,9 +1005,7 @@ class FederationServer(FederationBase): async with lock: logger.info("handling received PDU: %s", event) try: - await self.handler.on_receive_pdu( - origin, event, sent_to_us_directly=True - ) + await self.handler.on_receive_pdu(origin, event) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index de86918b7b..246df43501 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -203,18 +203,13 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - async def on_receive_pdu( - self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False - ) -> None: - """Process a PDU received via a federation /send/ transaction, or - via backfill of missing prev_events + async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: + """Process a PDU received via a federation /send/ transaction Args: origin: server which initiated the /send/ transaction. Will be used to fetch missing events or state. pdu: received PDU - sent_to_us_directly: True if this event was pushed to us; False if - we pulled it as the result of a missing prev_event. """ room_id = pdu.room_id @@ -276,8 +271,6 @@ class FederationHandler(BaseHandler): ) return None - state = None - # Check that the event passes auth based on the state at the event. This is # done for events that are to be added to the timeline (non-outliers). # @@ -285,83 +278,72 @@ class FederationHandler(BaseHandler): # - Fetching any missing prev events to fill in gaps in the graph # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): - if sent_to_us_directly: - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if missing_prevs: - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("min_depth: %d", min_depth) - - if min_depth is not None and pdu.depth > min_depth: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if missing_prevs: + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) + + if min_depth is not None and pdu.depth > min_depth: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + logger.info( + "Acquiring room lock to fetch %d missing prev_events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( - "Acquiring room lock to fetch %d missing prev_events: %s", + "Acquired room lock to fetch %d missing prev_events", len(missing_prevs), - shortstr(missing_prevs), ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): - logger.info( - "Acquired room lock to fetch %d missing prev_events", - len(missing_prevs), + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) from e - try: - await self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) - except Exception as e: - raise Exception( - "Error fetching missing prev_events for %s: %s" - % (event_id, e) - ) from e - - # Update the set of things we've seen after trying to - # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - logger.info("Found all missing prev_events") - - if missing_prevs: - # since this event was pushed to us, it is possible for it to - # become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very - # bad. Further, if the event was pushed to us, there is no excuse - # for us not to have all the prev_events. (XXX: apart from - # min_depth?) - # - # We therefore reject any such events. - logger.warning( - "Rejecting: failed to fetch %d prev events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen - else: - state = await self._resolve_state_at_missing_prevs(origin, pdu) + if not missing_prevs: + logger.info("Found all missing prev_events") + + if missing_prevs: + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. + logger.warning( + "Rejecting: failed to fetch %d prev events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) - # A second round of checks for all events. Check that the event passes auth - # based on `auth_events`, this allows us to assert that the event would - # have been allowed at some point. If an event passes this check its OK - # for it to be used as part of a returned `/state` request, as either - # a) we received the event as part of the original join and so trust it, or - # b) we'll do a state resolution with existing state before it becomes - # part of the "current state", which adds more protection. - await self._process_received_pdu(origin, pdu, state=state) + await self._process_received_pdu(origin, pdu, state=None) async def _get_missing_events_for_pdu( self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int @@ -461,24 +443,7 @@ class FederationHandler(BaseHandler): return logger.info("Got %d prev_events", len(missing_events)) - - # We want to sort these by depth so we process them and - # tell clients about them in order. - missing_events.sort(key=lambda x: x.depth) - - for ev in missing_events: - logger.info("Handling received prev_event %s", ev) - with nested_logging_context(ev.event_id): - try: - await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) - except FederationError as e: - if e.code == 403: - logger.warning( - "Received prev_event %s failed history check.", - ev.event_id, - ) - else: - raise + await self._process_pulled_events(origin, missing_events) async def _get_state_for_room( self, @@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler): event_infos, ) + async def _process_pulled_events( + self, origin: str, events: Iterable[EventBase] + ) -> None: + """Process a batch of events we have pulled from a remote server + + Pulls in any events required to auth the events, persists the received events, + and notifies clients, if appropriate. + + Assumes the events have already had their signatures and hashes checked. + + Params: + origin: The server we received these events from + events: The received events. + """ + + # We want to sort these by depth so we process them and + # tell clients about them in order. + sorted_events = sorted(events, key=lambda x: x.depth) + + for ev in sorted_events: + with nested_logging_context(ev.event_id): + await self._process_pulled_event(origin, ev) + + async def _process_pulled_event(self, origin: str, event: EventBase) -> None: + """Process a single event that we have pulled from a remote server + + Pulls in any events required to auth the event, persists the received event, + and notifies clients, if appropriate. + + Assumes the event has already had its signatures and hashes checked. + + This is somewhat equivalent to on_receive_pdu, but applies somewhat different + logic in the case that we are missing prev_events (in particular, it just + requests the state at that point, rather than triggering a get_missing_events) - + so is appropriate when we have pulled the event from a remote server, rather + than having it pushed to us. + + Params: + origin: The server we received this event from + events: The received event + """ + logger.info("Processing pulled event %s", event) + + # these should not be outliers. + assert not event.internal_metadata.is_outlier() + + event_id = event.event_id + + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + try: + self._sanity_check_event(event) + except SynapseError as err: + logger.warning("Event %s failed sanity check: %s", event_id, err) + return + + try: + state = await self._resolve_state_at_missing_prevs(origin, event) + await self._process_received_pdu(origin, event, state=state) + except FederationError as e: + if e.code == 403: + logger.warning("Pulled event %s failed history check.", event_id) + else: + raise + async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase ) -> Optional[Iterable[EventBase]]: @@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler): p, ) with nested_logging_context(p.event_id): - await self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e diff --git a/tests/test_federation.py b/tests/test_federation.py index 3785799f46..348fcb72a7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Send the join, it should return None (which is not an error) self.assertEqual( - self.get_success( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) - ), + self.get_success(self.handler.on_receive_pdu("test.serv", join_event)), None, ) @@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): with LoggingContext("test-context"): failure = self.get_failure( - self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True - ), + self.handler.on_receive_pdu("test.serv", lying_event), FederationError, ) -- cgit 1.5.1 From ee3b2ac59a5645cb213b3c11c473613b80fe1e0f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 20 Aug 2021 15:47:03 +0100 Subject: Validate device_keys for C-S /keys/query requests (#10593) * Validate device_keys for C-S /keys/query requests Closes #10354 A small, not particularly critical fix. I'm interested in seeing if we can find a more systematic approach though. #8445 is the place for any discussion. --- changelog.d/10593.bugfix | 1 + synapse/api/errors.py | 8 ++++ synapse/rest/client/keys.py | 16 ++++++- tests/rest/client/v2_alpha/test_keys.py | 77 +++++++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10593.bugfix create mode 100644 tests/rest/client/v2_alpha/test_keys.py (limited to 'tests') diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix new file mode 100644 index 0000000000..492e58a7a8 --- /dev/null +++ b/changelog.d/10593.bugfix @@ -0,0 +1 @@ +Reject Client-Server /keys/query requests which provide device_ids incorrectly. \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index dc662bca83..9480f448d7 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -147,6 +147,14 @@ class SynapseError(CodeMessageException): return cs_error(self.msg, self.errcode) +class InvalidAPICallError(SynapseError): + """You called an existing API endpoint, but fed that endpoint + invalid or incomplete data.""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON) + + class ProxiedRequestError(SynapseError): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index d0d9d30d40..012491f597 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,8 +15,9 @@ # limitations under the License. import logging +from typing import Any -from synapse.api.errors import SynapseError +from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.http.servlet import ( RestServlet, parse_integer, @@ -163,6 +164,19 @@ class KeyQueryServlet(RestServlet): device_id = requester.device_id timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) + + device_keys = body.get("device_keys") + if not isinstance(device_keys, dict): + raise InvalidAPICallError("'device_keys' must be a JSON object") + + def is_list_of_strings(values: Any) -> bool: + return isinstance(values, list) and all(isinstance(v, str) for v in values) + + if any(not is_list_of_strings(keys) for keys in device_keys.values()): + raise InvalidAPICallError( + "'device_keys' values must be a list of strings", + ) + result = await self.e2e_keys_handler.query_devices( body, timeout, user_id, device_id ) diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py new file mode 100644 index 0000000000..80a4e728ff --- /dev/null +++ b/tests/rest/client/v2_alpha/test_keys.py @@ -0,0 +1,77 @@ +from http import HTTPStatus + +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import keys, login + +from tests import unittest + + +class KeyQueryTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def test_rejects_device_id_ice_key_outside_of_list(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: "device_id1", + }, + }, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_rejects_device_key_given_as_map_to_bool(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: { + "device_id1": True, + }, + }, + }, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_requires_device_key(self): + """`device_keys` is required. We should complain if it's missing.""" + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + {}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) -- cgit 1.5.1 From 7862d704fdececf5a3a10466e2e7aee81fdf10f4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 20 Aug 2021 16:31:02 +0100 Subject: Follow-up: format changelog, add licence (#10593) Merged before approval; these comments from @clokep on that PR. --- changelog.d/10593.bugfix | 2 +- tests/rest/client/v2_alpha/test_keys.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix index 492e58a7a8..af910bfa4d 100644 --- a/changelog.d/10593.bugfix +++ b/changelog.d/10593.bugfix @@ -1 +1 @@ -Reject Client-Server /keys/query requests which provide device_ids incorrectly. \ No newline at end of file +Reject Client-Server `/keys/query` requests which provide `device_ids` incorrectly. diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py index 80a4e728ff..d7fa635eae 100644 --- a/tests/rest/client/v2_alpha/test_keys.py +++ b/tests/rest/client/v2_alpha/test_keys.py @@ -1,3 +1,17 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + from http import HTTPStatus from synapse.api.errors import Codes -- cgit 1.5.1 From f499dc38bcdf0cec83b873923af3d5310538759e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 20 Aug 2021 17:43:26 +0200 Subject: Simplify tests for the device admin rest API. (#10664) By replacing duplicated code with parameterized tests and avoiding unnecessary dumping of JSON data. --- changelog.d/10664.misc | 1 + tests/rest/admin/test_device.py | 99 ++++++++--------------------------------- 2 files changed, 19 insertions(+), 81 deletions(-) create mode 100644 changelog.d/10664.misc (limited to 'tests') diff --git a/changelog.d/10664.misc b/changelog.d/10664.misc new file mode 100644 index 0000000000..cebd5e9a96 --- /dev/null +++ b/changelog.d/10664.misc @@ -0,0 +1 @@ +Simplify tests for device admin rest API. \ No newline at end of file diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index c4afe5c3d9..a3679be205 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import urllib.parse +from parameterized import parameterized + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login @@ -45,49 +46,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.other_user_device_id, ) - def test_no_auth(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get a device of an user without authentication. """ - channel = self.make_request("GET", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("PUT", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("DELETE", self.url, b"{}") + channel = self.make_request(method, self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ channel = self.make_request( - "GET", - self.url, - access_token=self.other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "PUT", - self.url, - access_token=self.other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, self.url, access_token=self.other_user_token, ) @@ -95,7 +70,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_user_does_not_exist(self, method: str): """ Tests that a lookup for a user that does not exist returns a 404 """ @@ -105,7 +81,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) @@ -113,25 +89,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_user_is_not_local(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_user_is_not_local(self, method: str): """ Tests that a lookup for a user that is not a local returns a 400 """ @@ -141,25 +100,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) - - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) @@ -219,12 +160,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) } - body = json.dumps(update) channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=update, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -275,12 +215,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): Tests a normal successful update of display name """ # Set new display_name - body = json.dumps({"display_name": "new displayname"}) channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"display_name": "new displayname"}, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -529,12 +468,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ Tests that a remove of a device that does not exist returns 200. """ - body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"devices": ["unknown_device1", "unknown_device2"]}, ) # Delete unknown devices returns status 200 @@ -560,12 +498,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): device_ids.append(str(d["device_id"])) # Delete devices - body = json.dumps({"devices": device_ids}) channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"devices": device_ids}, ) self.assertEqual(200, channel.code, msg=channel.json_body) -- cgit 1.5.1 From ecd823d766fecdd7e1c7163073c097a0084122e2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 20 Aug 2021 17:50:44 +0100 Subject: Flatten tests/rest/client/{v1,v2_alpha} too (#10667) --- changelog.d/10667.misc | 1 + mypy.ini | 4 +- tests/rest/client/test_account.py | 1000 +++++++++ tests/rest/client/test_auth.py | 717 +++++++ tests/rest/client/test_capabilities.py | 212 ++ tests/rest/client/test_directory.py | 154 ++ tests/rest/client/test_events.py | 156 ++ tests/rest/client/test_filter.py | 111 + tests/rest/client/test_keys.py | 91 + tests/rest/client/test_login.py | 1345 ++++++++++++ tests/rest/client/test_password_policy.py | 177 ++ tests/rest/client/test_presence.py | 76 + tests/rest/client/test_profile.py | 270 +++ tests/rest/client/test_push_rule_attrs.py | 414 ++++ tests/rest/client/test_register.py | 746 +++++++ tests/rest/client/test_relations.py | 724 +++++++ tests/rest/client/test_report_event.py | 82 + tests/rest/client/test_rooms.py | 2150 ++++++++++++++++++++ tests/rest/client/test_sendtodevice.py | 200 ++ tests/rest/client/test_shared_rooms.py | 142 ++ tests/rest/client/test_sync.py | 686 +++++++ tests/rest/client/test_typing.py | 121 ++ tests/rest/client/test_upgrade_room.py | 159 ++ tests/rest/client/utils.py | 654 ++++++ tests/rest/client/v1/__init__.py | 13 - tests/rest/client/v1/test_directory.py | 154 -- tests/rest/client/v1/test_events.py | 156 -- tests/rest/client/v1/test_login.py | 1345 ------------ tests/rest/client/v1/test_presence.py | 76 - tests/rest/client/v1/test_profile.py | 270 --- tests/rest/client/v1/test_push_rule_attrs.py | 414 ---- tests/rest/client/v1/test_rooms.py | 2150 -------------------- tests/rest/client/v1/test_typing.py | 121 -- tests/rest/client/v1/utils.py | 654 ------ tests/rest/client/v2_alpha/__init__.py | 0 tests/rest/client/v2_alpha/test_account.py | 1000 --------- tests/rest/client/v2_alpha/test_auth.py | 717 ------- tests/rest/client/v2_alpha/test_capabilities.py | 212 -- tests/rest/client/v2_alpha/test_filter.py | 111 - tests/rest/client/v2_alpha/test_keys.py | 91 - tests/rest/client/v2_alpha/test_password_policy.py | 177 -- tests/rest/client/v2_alpha/test_register.py | 746 ------- tests/rest/client/v2_alpha/test_relations.py | 724 ------- tests/rest/client/v2_alpha/test_report_event.py | 82 - tests/rest/client/v2_alpha/test_sendtodevice.py | 200 -- tests/rest/client/v2_alpha/test_shared_rooms.py | 142 -- tests/rest/client/v2_alpha/test_sync.py | 686 ------- tests/rest/client/v2_alpha/test_upgrade_room.py | 159 -- tests/unittest.py | 2 +- 49 files changed, 10391 insertions(+), 10403 deletions(-) create mode 100644 changelog.d/10667.misc create mode 100644 tests/rest/client/test_account.py create mode 100644 tests/rest/client/test_auth.py create mode 100644 tests/rest/client/test_capabilities.py create mode 100644 tests/rest/client/test_directory.py create mode 100644 tests/rest/client/test_events.py create mode 100644 tests/rest/client/test_filter.py create mode 100644 tests/rest/client/test_keys.py create mode 100644 tests/rest/client/test_login.py create mode 100644 tests/rest/client/test_password_policy.py create mode 100644 tests/rest/client/test_presence.py create mode 100644 tests/rest/client/test_profile.py create mode 100644 tests/rest/client/test_push_rule_attrs.py create mode 100644 tests/rest/client/test_register.py create mode 100644 tests/rest/client/test_relations.py create mode 100644 tests/rest/client/test_report_event.py create mode 100644 tests/rest/client/test_rooms.py create mode 100644 tests/rest/client/test_sendtodevice.py create mode 100644 tests/rest/client/test_shared_rooms.py create mode 100644 tests/rest/client/test_sync.py create mode 100644 tests/rest/client/test_typing.py create mode 100644 tests/rest/client/test_upgrade_room.py create mode 100644 tests/rest/client/utils.py delete mode 100644 tests/rest/client/v1/__init__.py delete mode 100644 tests/rest/client/v1/test_directory.py delete mode 100644 tests/rest/client/v1/test_events.py delete mode 100644 tests/rest/client/v1/test_login.py delete mode 100644 tests/rest/client/v1/test_presence.py delete mode 100644 tests/rest/client/v1/test_profile.py delete mode 100644 tests/rest/client/v1/test_push_rule_attrs.py delete mode 100644 tests/rest/client/v1/test_rooms.py delete mode 100644 tests/rest/client/v1/test_typing.py delete mode 100644 tests/rest/client/v1/utils.py delete mode 100644 tests/rest/client/v2_alpha/__init__.py delete mode 100644 tests/rest/client/v2_alpha/test_account.py delete mode 100644 tests/rest/client/v2_alpha/test_auth.py delete mode 100644 tests/rest/client/v2_alpha/test_capabilities.py delete mode 100644 tests/rest/client/v2_alpha/test_filter.py delete mode 100644 tests/rest/client/v2_alpha/test_keys.py delete mode 100644 tests/rest/client/v2_alpha/test_password_policy.py delete mode 100644 tests/rest/client/v2_alpha/test_register.py delete mode 100644 tests/rest/client/v2_alpha/test_relations.py delete mode 100644 tests/rest/client/v2_alpha/test_report_event.py delete mode 100644 tests/rest/client/v2_alpha/test_sendtodevice.py delete mode 100644 tests/rest/client/v2_alpha/test_shared_rooms.py delete mode 100644 tests/rest/client/v2_alpha/test_sync.py delete mode 100644 tests/rest/client/v2_alpha/test_upgrade_room.py (limited to 'tests') diff --git a/changelog.d/10667.misc b/changelog.d/10667.misc new file mode 100644 index 0000000000..c92846ae26 --- /dev/null +++ b/changelog.d/10667.misc @@ -0,0 +1 @@ +Flatten the `tests.synapse.rests` package by moving the contents of `v1` and `v2_alpha` into the parent. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 90ade37b3f..b17872211e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,8 +91,8 @@ files = tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, tests/handlers/test_sync.py, - tests/rest/client/v1/test_login.py, - tests/rest/client/v2_alpha/test_auth.py, + tests/rest/client/test_login.py, + tests/rest/client/test_auth.py, tests/util/test_itertools.py, tests/util/test_stream_change_cache.py diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py new file mode 100644 index 0000000000..b946fca8b3 --- /dev/null +++ b/tests/rest/client/test_account.py @@ -0,0 +1,1000 @@ +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import re +from email.parser import Parser +from typing import Optional + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes, HttpResponseException +from synapse.appservice import ApplicationService +from synapse.rest.client import account, login, register, room +from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource + +from tests import unittest +from tests.server import FakeSite, make_request +from tests.unittest import override_config + + +class PasswordResetTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + + return hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.submit_token_resource = PasswordResetSubmitTokenResource(hs) + + def test_basic_password_reset(self): + """Test basic password reset flow""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email.""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. + # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + + def test_basic_password_reset_canonicalise_email(self): + """Test basic password reset flow + Request password reset with different spelling + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email_profile = "test@example.com" + email_passwort_reset = "TEST@EXAMPLE.COM" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email_profile, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email_passwort_reset, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + + def test_cant_reset_password_without_clicking_link(self): + """Test that we do actually need to click the link in the email""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to reset password without clicking the link + self._reset_password(new_password, session_id, client_secret, expected_code=401) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = "weasle" + + # Attempt to reset password without even requesting an email + self._reset_password(new_password, session_id, client_secret, expected_code=401) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + @unittest.override_config({"request_token_inhibit_3pid_errors": True}) + def test_password_reset_bad_email_inhibit_error(self): + """Test that triggering a password reset with an email address that isn't bound + to an account doesn't leak the lack of binding for that address if configured + that way. + """ + self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertIsNotNone(session_id) + + def _request_token(self, email, client_secret, ip="127.0.0.1"): + channel = self.make_request( + "POST", + b"account/password/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, + ) + + if channel.code != 200: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + # Load the password reset confirmation page + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "GET", + path, + shorthand=False, + ) + + self.assertEquals(200, channel.code, channel.result) + + # Now POST to the same endpoint, mimicking the same behaviour as clicking the + # password reset confirm button + + # Confirm the password reset + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "POST", + path, + content=b"", + shorthand=False, + content_is_form=True, + ) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) + + def _reset_password( + self, new_password, session_id, client_secret, expected_code=200 + ): + channel = self.make_request( + "POST", + b"account/password", + { + "new_password": new_password, + "auth": { + "type": LoginType.EMAIL_IDENTITY, + "threepid_creds": { + "client_secret": client_secret, + "sid": session_id, + }, + }, + }, + ) + self.assertEquals(expected_code, channel.code, channel.result) + + +class DeactivateTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def test_deactivate_account(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + self.deactivate(user_id, tok) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + channel = self.make_request("GET", "account/whoami", access_token=tok) + self.assertEqual(channel.code, 401) + + def test_pending_invites(self): + """Tests that deactivating a user rejects every pending invite for them.""" + store = self.hs.get_datastore() + + inviter_id = self.register_user("inviter", "test") + inviter_tok = self.login("inviter", "test") + + invitee_id = self.register_user("invitee", "test") + invitee_tok = self.login("invitee", "test") + + # Make @inviter:test invite @invitee:test in a new room. + room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) + self.helper.invite( + room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok + ) + + # Make sure the invite is here. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 1, pending_invites) + self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) + + # Deactivate @invitee:test. + self.deactivate(invitee_id, invitee_tok) + + # Check that the invite isn't there anymore. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 0, pending_invites) + + # Check that the membership of @invitee:test in the room is now "leave". + memberships = self.get_success( + store.get_rooms_for_local_user_where_membership_is( + invitee_id, [Membership.LEAVE] + ) + ) + self.assertEqual(len(memberships), 1, memberships) + self.assertEqual(memberships[0].room_id, room_id, memberships) + + def deactivate(self, user_id, tok): + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } + ) + channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(channel.code, 200) + + +class WhoamiTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + register.register_servlets, + ] + + def test_GET_whoami(self): + device_id = "wouldgohere" + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test", device_id=device_id) + + whoami = self.whoami(tok) + self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) + + def test_GET_whoami_appservices(self): + user_id = "@as:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": user_id, "exclusive": True}]}, + sender=user_id, + ) + self.hs.get_datastore().services_cache.append(appservice) + + whoami = self.whoami(as_token) + self.assertEqual(whoami, {"user_id": user_id}) + self.assertFalse(hasattr(whoami, "device_id")) + + def whoami(self, tok): + channel = self.make_request("GET", "account/whoami", {}, access_token=tok) + self.assertEqual(channel.code, 200) + return channel.json_body + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_valid_email(self): + self.get_success(self._add_email(self.email, self.email)) + + def test_add_valid_email_second_time(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) + + def test_add_valid_email_second_time_canonicalise(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) + + def test_add_email_no_at(self): + self.get_success( + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_two_at(self): + self.get_success( + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_bad_format(self): + self.get_success( + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_domain_to_lower(self): + self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + + def test_add_email_domain_with_umlaut(self): + self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + + def test_add_email_address_casefold(self): + self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + + def test_address_trim(self): + self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP""" + + # We expect to be able to set three emails before getting ratelimited. + self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed""" + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile""" + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed""" + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email""" + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + + # Ensure not providing a next_link parameter still works + self._request_token( + "something@example.com", + "some_secret", + next_link=None, + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + body, + ) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) + + return channel.json_body.get("sid") + + def _request_token_invalid_email( + self, + email, + expected_errcode, + expected_error, + client_secret="foobar", + ): + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(expected_errcode, channel.json_body["errcode"]) + self.assertEqual(expected_error, channel.json_body["error"]) + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + channel = self.make_request("GET", path, shorthand=False) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) + + def _add_email(self, request_email, expected_email): + """Test adding an email to profile""" + previous_email_attempts = len(self.email_attempts) + + client_secret = "foobar" + session_id = self._request_token(request_email, client_secret) + + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) + link = self._get_link_from_email() + + self._validate_token(link) + + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py new file mode 100644 index 0000000000..e2fcbdc63a --- /dev/null +++ b/tests/rest/client/test_auth.py @@ -0,0 +1,717 @@ +# Copyright 2018 New Vector +# Copyright 2020-2021 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +from twisted.internet.defer import succeed + +import synapse.rest.admin +from synapse.api.constants import LoginType +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.rest.client import account, auth, devices, login, register +from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.types import JsonDict, UserID + +from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC +from tests.rest.client.utils import TEST_OIDC_CONFIG +from tests.server import FakeChannel +from tests.unittest import override_config, skip_unless + + +class DummyRecaptchaChecker(UserInteractiveAuthChecker): + def __init__(self, hs): + super().__init__(hs) + self.recaptcha_attempts = [] + + def check_auth(self, authdict, clientip): + self.recaptcha_attempts.append((authdict, clientip)) + return succeed(True) + + +class FallbackAuthTests(unittest.HomeserverTestCase): + + servlets = [ + auth.register_servlets, + register.register_servlets, + ] + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + + config["enable_registration_captcha"] = True + config["recaptcha_public_key"] = "brokencake" + config["registrations_require_3pid"] = [] + + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker + + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" + channel = self.make_request("POST", "register", body) + + self.assertEqual(channel.code, expected_response) + return channel + + def recaptcha( + self, + session: str, + expected_post_response: int, + post_session: Optional[str] = None, + ) -> None: + """Get and respond to a fallback recaptcha. Returns the second request.""" + if post_session is None: + post_session = session + + channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.assertEqual(channel.code, 200) + + channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + post_session + + "&g-recaptcha-response=a", + ) + self.assertEqual(channel.code, expected_post_response) + + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" + # Returns a 401 as per the spec + channel = self.register( + 401, + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step, we can then send a + # request to the register API with the session in the authdict. + channel = self.register(200, {"auth": {"session": session}}) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + # Make the initial request to register. (Later on a different password + # will be used.) + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} + ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" + + if HAS_OIDC: + # we enable OIDC as a way of testing SSO flows + oidc_config = {} + oidc_config.update(TEST_OIDC_CONFIG) + oidc_config["allow_existing_users"] = True + config["oidc_config"] = oidc_config + + return config + + def create_resource_dict(self): + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) + + def delete_device( + self, + access_token: str, + device: str, + expected_response: int, + body: Union[bytes, JsonDict] = b"", + ) -> FakeChannel: + """Delete an individual device.""" + channel = self.make_request( + "DELETE", + "devices/" + device, + body, + access_token=access_token, + ) + + # Ensure the response is sane. + self.assertEqual(channel.code, expected_response) + + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. + channel = self.make_request( + "POST", + "delete_devices", + body, + access_token=self.user_tok, + ) + + # Ensure the response is sane. + self.assertEqual(channel.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + self.device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_grandfathered_identifier(self): + """Check behaviour without "identifier" dict + + Synapse used to require clients to submit a "user" field for m.login.password + UIA - check that still works. + """ + + channel = self.delete_device(self.user_tok, self.device_id, 401) + session = channel.json_body["session"] + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + self.device_id, + 200, + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_can_change_body(self): + """ + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. + """ + # Create a second login. + self.login("test", self.user_pass, "dev2") + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [self.device_id]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. + self.delete_devices( + 200, + { + "devices": ["dev2"], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass, "dev2") + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + # + # This makes use of the fact that the device ID is embedded into the URL. + self.delete_device( + self.user_tok, + "dev2", + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) + def test_can_reuse_session(self): + """ + The session can be reused if configured. + + Compare to test_cannot_change_uri. + """ + # Create a second and third login. + self.login("test", self.user_pass, "dev2") + self.login("test", self.user_pass, "dev3") + + # Attempt to delete a device. This works since the user just logged in. + self.delete_device(self.user_tok, "dev2", 200) + + # Move the clock forward past the validation timeout. + self.reactor.advance(6) + + # Deleting another devices throws the user into UI auth. + channel = self.delete_device(self.user_tok, "dev3", 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + "dev3", + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + # Make another request, but try to delete the first device. This works + # due to re-using the previous session. + # + # Note that *no auth* information is provided, not even a session iD! + self.delete_device(self.user_tok, self.device_id, 200) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_via_sso(self): + """Test a successful UI Auth flow via SSO + + This includes: + * hitting the UIA SSO redirect endpoint + * checking it serves a confirmation page which links to the OIDC provider + * calling back to the synapse oidc callback + * checking that the original operation succeeds + """ + + # log the user in + remote_user_id = UserID.from_string(self.user).localpart + login_resp = self.helper.login_via_oidc(remote_user_id) + self.assertEqual(login_resp["user_id"], self.user) + + # initiate a UI Auth process by attempting to delete the device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # check that SSO is offered + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + + # run the UIA-via-SSO flow + session_id = channel.json_body["session"] + channel = self.helper.auth_via_oidc( + {"sub": remote_user_id}, ui_auth_session_id=session_id + ) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + + # and now the delete request should succeed. + self.delete_device( + self.user_tok, + self.device_id, + 200, + body={"auth": {"session": session_id}}, + ) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_does_not_offer_password_for_sso_user(self): + login_resp = self.helper.login_via_oidc("username") + user_tok = login_resp["access_token"] + device_id = login_resp["device_id"] + + # now call the device deletion API: we should get the option to auth with SSO + # and not password. + channel = self.delete_device(user_tok, device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) + + def test_does_not_offer_sso_for_password_user(self): + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.password"]}]) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_offers_both_flows_for_upgraded_user(self): + """A user that had a password and then logged in with SSO should get both flows""" + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + # we have no particular expectations of ordering here + self.assertIn({"stages": ["m.login.password"]}, flows) + self.assertIn({"stages": ["m.login.sso"]}, flows) + self.assertEqual(len(flows), 2) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_fails_for_incorrect_sso_user(self): + """If the user tries to authenticate with the wrong SSO user, they get an error""" + # log the user in + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + # start a UI Auth flow by attempting to delete a device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + session_id = channel.json_body["session"] + + # do the OIDC auth, but auth as the wrong user + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) + + # that should return a failure message + self.assertSubstring("We were unable to validate", channel.text_body) + + # ... and the delete op should now fail with a 403 + self.delete_device( + self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + ) + + +class RefreshAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + hijack_auth = False + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + + def test_login_issue_refresh_token(self): + """ + A login response should include a refresh_token only if asked. + """ + # Test login + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + + login_without_refresh = self.make_request( + "POST", "/_matrix/client/r0/login", body + ) + self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertNotIn("refresh_token", login_without_refresh.json_body) + + login_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertIn("refresh_token", login_with_refresh.json_body) + self.assertIn("expires_in_ms", login_with_refresh.json_body) + + def test_register_issue_refresh_token(self): + """ + A register response should include a refresh_token only if asked. + """ + register_without_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": "test2", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual( + register_without_refresh.code, 200, register_without_refresh.result + ) + self.assertNotIn("refresh_token", register_without_refresh.json_body) + + register_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", + { + "username": "test3", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertIn("refresh_token", register_with_refresh.json_body) + self.assertIn("expires_in_ms", register_with_refresh.json_body) + + def test_token_refresh(self): + """ + A refresh token can be used to issue a new access token. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertIn("access_token", refresh_response.json_body) + self.assertIn("refresh_token", refresh_response.json_body) + self.assertIn("expires_in_ms", refresh_response.json_body) + + # The access and refresh tokens should be different from the original ones after refresh + self.assertNotEqual( + login_response.json_body["access_token"], + refresh_response.json_body["access_token"], + ) + self.assertNotEqual( + login_response.json_body["refresh_token"], + refresh_response.json_body["refresh_token"], + ) + + @override_config({"access_token_lifetime": "1m"}) + def test_refresh_token_expiration(self): + """ + The access token should have some time as specified in the config. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + self.assertApproximates( + login_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertApproximates( + refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + def test_refresh_token_invalidation(self): + """Refresh tokens are invalidated after first use of the next token. + + A refresh token is considered invalid if: + - it was already used at least once + - and either + - the next access token was used + - the next refresh token was used + + The chain of tokens goes like this: + + login -|-> first_refresh -> third_refresh (fails) + |-> second_refresh -> fifth_refresh + |-> fourth_refresh (fails) + """ + + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + # This first refresh should work properly + first_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + first_refresh_response.code, 200, first_refresh_response.result + ) + + # This one as well, since the token in the first one was never used + second_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + second_refresh_response.code, 200, second_refresh_response.result + ) + + # This one should not, since the token from the first refresh is not valid anymore + third_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": first_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + third_refresh_response.code, 401, third_refresh_response.result + ) + + # The associated access token should also be invalid + whoami_response = self.make_request( + "GET", + "/_matrix/client/r0/account/whoami", + access_token=first_refresh_response.json_body["access_token"], + ) + self.assertEqual(whoami_response.code, 401, whoami_response.result) + + # But all other tokens should work (they will expire after some time) + for access_token in [ + second_refresh_response.json_body["access_token"], + login_response.json_body["access_token"], + ]: + whoami_response = self.make_request( + "GET", "/_matrix/client/r0/account/whoami", access_token=access_token + ) + self.assertEqual(whoami_response.code, 200, whoami_response.result) + + # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail + fourth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fourth_refresh_response.code, 403, fourth_refresh_response.result + ) + + # But refreshing from the last valid refresh token still works + fifth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": second_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fifth_refresh_response.code, 200, fifth_refresh_response.result + ) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py new file mode 100644 index 0000000000..ac31e5ceaf --- /dev/null +++ b/tests/rest/client/test_capabilities.py @@ -0,0 +1,212 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.rest.client import capabilities, login + +from tests import unittest +from tests.unittest import override_config + + +class CapabilitiesTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + capabilities.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.url = b"/_matrix/client/r0/capabilities" + hs = self.setup_test_homeserver() + self.config = hs.config + self.auth_handler = hs.get_auth_handler() + return hs + + def prepare(self, reactor, clock, hs): + self.localpart = "user" + self.password = "pass" + self.user = self.register_user(self.localpart, self.password) + + def test_check_auth_required(self): + channel = self.make_request("GET", self.url) + + self.assertEqual(channel.code, 401) + + def test_get_room_version_capabilities(self): + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + for room_version in capabilities["m.room_versions"]["available"].keys(): + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) + + self.assertEqual( + self.config.default_room_version.identifier, + capabilities["m.room_versions"]["default"], + ) + + def test_get_change_password_capabilities_password_login(self): + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"localdb_enabled": False}}) + def test_get_change_password_capabilities_localdb_disabled(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"enabled": False}}) + def test_get_change_password_capabilities_password_disabled(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): + """Test that per default msc3283 is disabled server returns `m.change_password`.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) + self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) + self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) + + @override_config({"experimental_features": {"msc3283_enabled": True}}) + def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): + """Test if msc3283 is enabled server returns capabilities.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + @override_config( + { + "enable_set_displayname": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_displayname_capabilities_displayname_disabled(self): + """Test if set displayname is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + + @override_config( + { + "enable_set_avatar_url": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + """Test if set avatar_url is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + + @override_config( + { + "enable_3pid_changes": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_change_3pid_capabilities_3pid_disabled(self): + """Test if change 3pid is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + def test_get_does_not_include_msc3244_fields_by_default(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertNotIn( + "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] + ) + + @override_config({"experimental_features": {"msc3244_enabled": True}}) + def test_get_does_include_msc3244_fields_when_enabled(self): + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + for details in capabilities["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ].values(): + if details["preferred"] is not None: + self.assertTrue( + details["preferred"] in KNOWN_ROOM_VERSIONS, + str(details["preferred"]), + ) + + self.assertGreater(len(details["support"]), 0) + for room_version in details["support"]: + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py new file mode 100644 index 0000000000..d2181ea907 --- /dev/null +++ b/tests/rest/client/test_directory.py @@ -0,0 +1,154 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.rest import admin +from synapse.rest.client import directory, login, room +from synapse.types import RoomAlias +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.unittest import override_config + + +class DirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_membership_for_aliases"] = True + + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.user = self.register_user("user", "test") + self.user_tok = self.login("user", "test") + + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) + + def test_directory_endpoint_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_directory(403) + + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + + def test_directory_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(400, alias_length=256) + + @override_config({"default_room_version": 5}) + def test_state_event_user_in_v5_room(self): + """Test that a regular user can add alias events before room v6""" + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + + @override_config({"default_room_version": 6}) + def test_state_event_v6_room(self): + """Test that a regular user can *not* add alias events from room v6""" + self.ensure_user_joined_room() + self.set_alias_via_state_event(403) + + def test_directory_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(200) + + def test_room_creation_too_long(self): + url = "/_matrix/client/r0/createRoom" + + # We use deliberately a localpart under the length threshold so + # that we can make sure that the check is done on the whole alias. + data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} + request_data = json.dumps(data) + channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_room_creation(self): + url = "/_matrix/client/r0/createRoom" + + # Check with an alias of allowed length. There should already be + # a test that ensures it works in test_register.py, but let's be + # as cautious as possible here. + data = {"room_alias_name": random_string(5)} + request_data = json.dumps(data) + channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, 200, channel.result) + + def set_alias_via_state_event(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def set_alias_via_directory(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def random_alias(self, length): + return RoomAlias(random_string(length), self.hs.hostname).to_string() + + def ensure_user_left_room(self): + self.ensure_membership("leave") + + def ensure_user_joined_room(self): + self.ensure_membership("join") + + def ensure_membership(self, membership): + try: + if membership == "leave": + self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) + if membership == "join": + self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py new file mode 100644 index 0000000000..a90294003e --- /dev/null +++ b/tests/rest/client/test_events.py @@ -0,0 +1,156 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Tests REST events for /events paths.""" + +from unittest.mock import Mock + +import synapse.rest.admin +from synapse.rest.client import events, login, room + +from tests import unittest + + +class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): + """Tests event streaming (GET /events).""" + + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["enable_registration_captcha"] = False + config["enable_registration"] = True + config["auto_join_rooms"] = [] + + hs = self.setup_test_homeserver(config=config) + + hs.get_federation_handler = Mock() + + return hs + + def prepare(self, reactor, clock, hs): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + # register a 2nd account + self.other_user = self.register_user("other2", "pass") + self.other_token = self.login(self.other_user, "pass") + + def test_stream_basic_permissions(self): + # invalid token, expect 401 + # note: this is in violation of the original v1 spec, which expected + # 403. However, since the v1 spec no longer exists and the v1 + # implementation is now part of the r0 implementation, the newer + # behaviour is used instead to be consistent with the r0 spec. + # see issue #2602 + channel = self.make_request( + "GET", "/events?access_token=%s" % ("invalid" + self.token,) + ) + self.assertEquals(channel.code, 401, msg=channel.result) + + # valid token, expect content + channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) + ) + self.assertEquals(channel.code, 200, msg=channel.result) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("start" in channel.json_body) + self.assertTrue("end" in channel.json_body) + + def test_stream_room_permissions(self): + room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) + self.helper.send(room_id, tok=self.other_token) + + # invited to room (expect no content for room) + self.helper.invite( + room_id, src=self.other_user, targ=self.user_id, tok=self.other_token + ) + + # valid token, expect content + channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) + ) + self.assertEquals(channel.code, 200, msg=channel.result) + + # We may get a presence event for ourselves down + self.assertEquals( + 0, + len( + [ + c + for c in channel.json_body["chunk"] + if not ( + c.get("type") == "m.presence" + and c["content"].get("user_id") == self.user_id + ) + ] + ), + ) + + # joined room (expect all content for room) + self.helper.join(room=room_id, user=self.user_id, tok=self.token) + + # left to room (expect no content for room) + + def TODO_test_stream_items(self): + # new user, no content + + # join room, expect 1 item (join) + + # send message, expect 2 items (join,send) + + # set topic, expect 3 items (join,send,topic) + + # someone else join room, expect 4 (join,send,topic,join) + + # someone else send message, expect 5 (join,send.topic,join,send) + + # someone else set topic, expect 6 (join,send,topic,join,send,topic) + pass + + +class GetEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, hs, reactor, clock): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def test_get_event_via_events(self): + resp = self.helper.send(self.room_id, tok=self.token) + event_id = resp["event_id"] + + channel = self.make_request( + "GET", + "/events/" + event_id, + access_token=self.token, + ) + self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py new file mode 100644 index 0000000000..475c6bed3d --- /dev/null +++ b/tests/rest/client/test_filter.py @@ -0,0 +1,111 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import Codes +from synapse.rest.client import filter + +from tests import unittest + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class FilterTestCase(unittest.HomeserverTestCase): + + user_id = "@apple:test" + hijack_auth = True + EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} + EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' + servlets = [filter.register_servlets] + + def prepare(self, reactor, clock, hs): + self.filtering = hs.get_filtering() + self.store = hs.get_datastore() + + def test_add_filter(self): + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.user_id), + self.EXAMPLE_FILTER_JSON, + ) + + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.json_body, {"filter_id": "0"}) + filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.pump() + self.assertEquals(filter.result, self.EXAMPLE_FILTER) + + def test_add_filter_for_other_user(self): + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), + self.EXAMPLE_FILTER_JSON, + ) + + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + + def test_add_filter_non_local_user(self): + _is_mine = self.hs.is_mine + self.hs.is_mine = lambda target_user: False + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.user_id), + self.EXAMPLE_FILTER_JSON, + ) + + self.hs.is_mine = _is_mine + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + + def test_get_filter(self): + filter_id = defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER + ) + ) + self.reactor.advance(1) + filter_id = filter_id.result + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) + ) + + self.assertEqual(channel.result["code"], b"200") + self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) + + def test_get_filter_non_existant(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"404") + self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) + + # Currently invalid params do not have an appropriate errcode + # in errors.py + def test_get_filter_invalid_id(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"400") + + # No ID also returns an invalid_id error + def test_get_filter_no_id(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py new file mode 100644 index 0000000000..d7fa635eae --- /dev/null +++ b/tests/rest/client/test_keys.py @@ -0,0 +1,91 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from http import HTTPStatus + +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import keys, login + +from tests import unittest + + +class KeyQueryTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def test_rejects_device_id_ice_key_outside_of_list(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: "device_id1", + }, + }, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_rejects_device_key_given_as_map_to_bool(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: { + "device_id1": True, + }, + }, + }, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_requires_device_key(self): + """`device_keys` is required. We should complain if it's missing.""" + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + {}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py new file mode 100644 index 0000000000..5b2243fe52 --- /dev/null +++ b/tests/rest/client/test_login.py @@ -0,0 +1,1345 @@ +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import urllib.parse +from typing import Any, Dict, List, Optional, Union +from unittest.mock import Mock +from urllib.parse import urlencode + +import pymacaroons + +from twisted.web.resource import Resource + +import synapse.rest.admin +from synapse.appservice import ApplicationService +from synapse.rest.client import devices, login, logout, register +from synapse.rest.client.account import WhoamiRestServlet +from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.types import create_requester + +from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC +from tests.handlers.test_saml import has_saml2 +from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.test_utils.html_parsers import TestHtmlParser +from tests.unittest import HomeserverTestCase, override_config, skip_unless + +try: + import jwt + + HAS_JWT = True +except ImportError: + HAS_JWT = False + + +# synapse server name: used to populate public_baseurl in some tests +SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" + +# public_baseurl for some tests. It uses an http:// scheme because +# FakeChannel.isSecure() returns False, so synapse will see the requested uri as +# http://..., so using http in the public_baseurl stops Synapse trying to redirect to +# https://.... +BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) + +# CAS server used in some tests +CAS_SERVER = "https://fake.test" + +# just enough to tell pysaml2 where to redirect to +SAML_SERVER = "https://test.saml.server/idp/sso" +TEST_SAML_METADATA = """ + + + + + +""" % { + "SAML_SERVER": SAML_SERVER, +} + +LOGIN_URL = b"/_matrix/client/r0/login" +TEST_URL = b"/_matrix/client/r0/account/whoami" + +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' + +# the query params in TEST_CLIENT_REDIRECT_URL +EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] + +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + + +class LoginRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + logout.register_servlets, + devices.register_servlets, + lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] + self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False + + return self.hs + + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000}, + } + } + ) + def test_POST_ratelimiting_per_address(self): + # Create different users so we're sure not to be bothered by the per-user + # ratelimiter. + for i in range(0, 6): + self.register_user("kermit" + str(i), "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + } + } + ) + def test_POST_ratelimiting_per_account(self): + self.register_user("kermit", "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.0) + + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, + } + } + ) + def test_POST_ratelimiting_per_account_failed_attempts(self): + self.register_user("kermit", "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "notamonkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"403", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "notamonkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_soft_logout(self): + self.register_user("kermit", "monkey") + + # we shouldn't be able to make requests without an access token + channel = self.make_request(b"GET", TEST_URL) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + + # log in as normal + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.code, 200, channel.result) + access_token = channel.json_body["access_token"] + device_id = channel.json_body["device_id"] + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # + # test behaviour after deleting the expired device + # + + # we now log in as a different device + access_token_2 = self.login("kermit", "monkey") + + # more requests with the expired token should still return a soft-logout + self.reactor.advance(3600) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # ... but if we delete that device, it will be a proper logout + self._delete_device(access_token_2, "kermit", "monkey", device_id) + + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], False) + + def _delete_device(self, access_token, user_id, password, device_id): + """Perform the UI-Auth to delete a device""" + channel = self.make_request( + b"DELETE", "devices/" + device_id, access_token=access_token + ) + self.assertEquals(channel.code, 401, channel.result) + # check it's a UI-Auth fail + self.assertEqual( + set(channel.json_body.keys()), + {"flows", "params", "session"}, + channel.result, + ) + + auth = { + "type": "m.login.password", + # https://github.com/matrix-org/synapse/issues/5665 + # "identifier": {"type": "m.id.user", "user": user_id}, + "user": user_id, + "password": password, + "session": channel.json_body["session"], + } + + channel = self.make_request( + b"DELETE", + "devices/" + device_id, + access_token=access_token, + content={"auth": auth}, + ) + self.assertEquals(channel.code, 200, channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard logout this session + channel = self.make_request(b"POST", "/logout", access_token=access_token) + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard log out all of the user's sessions + channel = self.make_request(b"POST", "/logout/all", access_token=access_token) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") +class MultiSSOTestCase(unittest.HomeserverTestCase): + """Tests for homeservers with multiple SSO providers enabled""" + + servlets = [ + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + config["public_baseurl"] = BASE_URL + + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + "service_url": "https://matrix.goodserver.com:8448", + } + + config["saml2_config"] = { + "sp_config": { + "metadata": {"inline": [TEST_SAML_METADATA]}, + # use the XMLSecurity backend to avoid relying on xmlsec1 + "crypto_backend": "XMLSecurity", + }, + } + + # default OIDC provider + config["oidc_config"] = TEST_OIDC_CONFIG + + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d.update(build_synapse_client_resource_tree(self.hs)) + return d + + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flow_types = [ + "m.login.cas", + "m.login.sso", + "m.login.token", + "m.login.password", + ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] + + self.assertCountEqual( + [f["type"] for f in channel.json_body["flows"]], expected_flow_types + ) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results: Dict[str, Any] = {} + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + + def test_multi_sso_redirect(self): + """/login/sso/redirect should redirect to an identity picker""" + # first hit the redirect url, which should redirect to our idp picker + channel = self._make_sso_redirect_request(False, None) + self.assertEqual(channel.code, 302, channel.result) + uri = channel.headers.getRawHeaders("Location")[0] + + # hitting that picker should give us some HTML + channel = self.make_request("GET", uri) + self.assertEqual(channel.code, 200, channel.result) + + # parse the form to check it has fields assumed elsewhere in this class + html = channel.result["body"].decode("utf-8") + p = TestHtmlParser() + p.feed(html) + p.close() + + # there should be a link for each href + returned_idps: List[str] = [] + for link in p.links: + path, query = link.split("?", 1) + self.assertEqual(path, "pick_idp") + params = urllib.parse.parse_qs(query) + self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) + returned_idps.append(params["idp"][0]) + + self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) + + def test_multi_sso_redirect_to_cas(self): + """If CAS is chosen, should redirect to the CAS server""" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", + shorthand=False, + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + cas_uri = location_headers[0] + cas_uri_path, cas_uri_query = cas_uri.split("?", 1) + + # it should redirect us to the login page of the cas server + self.assertEqual(cas_uri_path, CAS_SERVER + "/login") + + # check that the redirectUrl is correctly encoded in the service param - ie, the + # place that CAS will redirect to + cas_uri_params = urllib.parse.parse_qs(cas_uri_query) + service_uri = cas_uri_params["service"][0] + _, service_uri_query = service_uri.split("?", 1) + service_uri_params = urllib.parse.parse_qs(service_uri_query) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) + + def test_multi_sso_redirect_to_saml(self): + """If SAML is chosen, should redirect to the SAML server""" + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=saml", + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + saml_uri = location_headers[0] + saml_uri_path, saml_uri_query = saml_uri.split("?", 1) + + # it should redirect us to the login page of the SAML server + self.assertEqual(saml_uri_path, SAML_SERVER) + + # the RelayState is used to carry the client redirect url + saml_uri_params = urllib.parse.parse_qs(saml_uri_query) + relay_state_param = saml_uri_params["RelayState"][0] + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) + + def test_login_via_oidc(self): + """If OIDC is chosen, should redirect to the OIDC auth endpoint""" + + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + # ... and should have set a cookie including the redirect url + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + oidc_session_cookie = cookies["oidc_session"] + macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) + self.assertEqual( + self._get_value_from_macaroon(macaroon, "client_redirect_url"), + TEST_CLIENT_REDIRECT_URL, + ) + + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + content_type_headers = channel.headers.getRawHeaders("Content-Type") + assert content_type_headers + self.assertTrue(content_type_headers[-1].startswith("text/html")) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + + def test_multi_sso_redirect_to_unknown(self): + """An unknown IdP should cause a 400""" + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self._make_sso_redirect_request(False, "xxx") + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self._make_sso_redirect_request(False, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_msc2858_redirect_to_oidc(self): + """Test the unstable API""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + def _make_sso_redirect_request( + self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None + ): + """Send a request to /_matrix/client/r0/login/sso/redirect + + ... or the unstable equivalent + + ... possibly specifying an IDP provider + """ + endpoint = ( + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" + if unstable_endpoint + else "/_matrix/client/r0/login/sso/redirect" + ) + if idp_prov is not None: + endpoint += "/" + idp_prov + endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + return self.make_request( + "GET", + endpoint, + custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], + ) + + @staticmethod + def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + +class CASTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.base_url = "https://matrix.goodserver.com/" + self.redirect_path = "_synapse/client/login/sso/redirect/confirm" + + config = self.default_config() + config["public_baseurl"] = ( + config.get("public_baseurl") or "https://matrix.goodserver.com:8448" + ) + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + } + + cas_user_id = "username" + self.user_id = "@%s:test" % cas_user_id + + async def get_raw(uri, args): + """Return an example response payload from a call to the `/proxyValidate` + endpoint of a CAS server, copied from + https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 + + This needs to be returned by an async function (as opposed to set as the + mock's return value) because the corresponding Synapse code awaits on it. + """ + return ( + """ + + + %s + PGTIOU-84678-8a9d... + + https://proxy2/pgtUrl + https://proxy1/pgtUrl + + + + """ + % cas_user_id + ).encode("utf-8") + + mocked_http_client = Mock(spec=["get_raw"]) + mocked_http_client.get_raw.side_effect = get_raw + + self.hs = self.setup_test_homeserver( + config=config, + proxied_http_client=mocked_http_client, + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.deactivate_account_handler = hs.get_deactivate_account_handler() + + def test_cas_redirect_confirm(self): + """Tests that the SSO login flow serves a confirmation page before redirecting a + user to the redirect URL. + """ + base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" + redirect_url = "https://dodgy-site.com/" + + url_parts = list(urllib.parse.urlparse(base_url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"redirectUrl": redirect_url}) + query.update({"ticket": "ticket"}) + url_parts[4] = urllib.parse.urlencode(query) + cas_ticket_url = urllib.parse.urlunparse(url_parts) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + # Test that the response is HTML. + self.assertEqual(channel.code, 200, channel.result) + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. + self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url""" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + # Because the user is deactivated they are served an error template. + self.assertEqual(channel.code, 403) + self.assertIn(b"SSO account deactivated", channel.result["body"]) + + +@skip_unless(HAS_JWT, "requires jwt") +class JWTTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + jwt_secret = "secret" + jwt_algorithm = "HS256" + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_secret + self.hs.config.jwt_algorithm = self.jwt_algorithm + return self.hs + + def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result + + def jwt_login(self, *args): + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + channel = self.make_request(b"POST", LOGIN_URL, params) + return channel + + def test_login_jwt_valid_registered(self): + self.register_user("kermit", "monkey") + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_valid_unregistered(self): + channel = self.jwt_login({"sub": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, "notsecret") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) + + def test_login_jwt_expired(self): + channel = self.jwt_login({"sub": "frog", "exp": 864000}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Signature has expired" + ) + + def test_login_jwt_not_before(self): + now = int(time.time()) + channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: The token is not yet valid (nbf)", + ) + + def test_login_no_sub(self): + channel = self.jwt_login({"username": "root"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "issuer": "test-issuer", + } + } + ) + def test_login_iss(self): + """Test validating the issuer claim.""" + # A valid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid issuer" + ) + + # Not providing an issuer. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "iss" claim', + ) + + def test_login_iss_no_config(self): + """Test providing an issuer claim without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "audiences": ["test-audience"], + } + } + ) + def test_login_aud(self): + """Test validating the audience claim.""" + # A valid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + # Not providing an audience. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "aud" claim', + ) + + def test_login_aud_no_config(self): + """Test providing an audience without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + def test_login_no_token(self): + params = {"type": "org.matrix.login.jwt"} + channel = self.make_request(b"POST", LOGIN_URL, params) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") + + +# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use +# RSS256, with a public key configured in synapse as "jwt_secret", and tokens +# signed by the private key. +@skip_unless(HAS_JWT, "requires jwt") +class JWTPubKeyTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + ] + + # This key's pubkey is used as the jwt_secret setting of synapse. Valid + # tokens are signed by this and validated using the pubkey. It is generated + # with `openssl genrsa 512` (not a secure way to generate real keys, but + # good enough for tests!) + jwt_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", + "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", + "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", + "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", + "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", + "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", + "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", + "-----END RSA PRIVATE KEY-----", + ] + ) + + # Generated with `openssl rsa -in foo.key -pubout`, with the the above + # private key placed in foo.key (jwt_privatekey). + jwt_pubkey = "\n".join( + [ + "-----BEGIN PUBLIC KEY-----", + "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", + "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", + "-----END PUBLIC KEY-----", + ] + ) + + # This key is used to sign tokens that shouldn't be accepted by synapse. + # Generated just like jwt_privatekey. + bad_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", + "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", + "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", + "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", + "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", + "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", + "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", + "-----END RSA PRIVATE KEY-----", + ] + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_pubkey + self.hs.config.jwt_algorithm = "RS256" + return self.hs + + def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result + + def jwt_login(self, *args): + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + channel = self.make_request(b"POST", LOGIN_URL, params) + return channel + + def test_login_jwt_valid(self): + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) + + +AS_USER = "as_user_alice" + + +class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register.register_servlets, + ] + + def register_as_user(self, username): + self.make_request( + b"POST", + "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), + {"username": username}, + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + + self.service = ApplicationService( + id="unique_identifier", + token="some_token", + hostname="example.com", + sender="@asbot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + self.another_service = ApplicationService( + id="another__identifier", + token="another_token", + hostname="example.com", + sender="@as2bot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as2_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + + self.hs.get_datastore().services_cache.append(self.service) + self.hs.get_datastore().services_cache.append(self.another_service) + return self.hs + + def test_login_appservice_user(self): + """Test that an appservice user can use /login""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_user_bot(self): + """Test that the appservice bot can use /login""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": self.service.sender}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_wrong_user(self): + """Test that non-as users cannot login with the as token""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_wrong_as(self): + """Test that as users cannot login with wrong as token""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.another_service.token + ) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_no_token(self): + """Test that users must provide a token when using the appservice + login method + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + +@skip_unless(HAS_OIDC, "requires OIDC") +class UsernamePickerTestCase(HomeserverTestCase): + """Tests for the username picker flow of SSO login""" + + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + + config["oidc_config"] = {} + config["oidc_config"].update(TEST_OIDC_CONFIG) + config["oidc_config"]["user_mapping_provider"] = { + "config": {"display_name_template": "{{ user.displayname }}"} + } + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://x"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d.update(build_synapse_client_resource_tree(self.hs)) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + + # do the start of the login flow + channel = self.helper.auth_via_oidc( + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + ) + + # that should redirect to the username picker + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + picker_url = location_headers[0] + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") + + # ... with a username_mapping_session cookie + cookies: Dict[str, str] = {} + channel.extract_cookies(cookies) + self.assertIn("username_mapping_session", cookies) + session_id = cookies["username_mapping_session"] + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, + username_mapping_sessions, + "session id not found in map", + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # to the completion page + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=picker_url, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # fish the login token out of the returned redirect uri + login_token = params[2][1] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py new file mode 100644 index 0000000000..3cf5871899 --- /dev/null +++ b/tests/rest/client/test_password_policy.py @@ -0,0 +1,177 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import account, login, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + channel = self.make_request("GET", "/_matrix/client/r0/password_policy") + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_TOO_SHORT, + channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_DIGIT, + channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_SYMBOL, + channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_UPPERCASE, + channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_LOWERCASE, + channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + channel = self.make_request("POST", self.register_url, request_data) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py new file mode 100644 index 0000000000..1d152352d1 --- /dev/null +++ b/tests/rest/client/test_presence.py @@ -0,0 +1,76 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from twisted.internet import defer + +from synapse.handlers.presence import PresenceHandler +from synapse.rest.client import presence +from synapse.types import UserID + +from tests import unittest + + +class PresenceTestCase(unittest.HomeserverTestCase): + """Tests presence REST API.""" + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [presence.register_servlets] + + def make_homeserver(self, reactor, clock): + + presence_handler = Mock(spec=PresenceHandler) + presence_handler.set_state.return_value = defer.succeed(None) + + hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + presence_handler=presence_handler, + ) + + return hs + + def test_put_presence(self): + """ + PUT to the status endpoint with use_presence enabled will call + set_state on the presence handler. + """ + self.hs.config.use_presence = True + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + + @unittest.override_config({"use_presence": False}) + def test_put_presence_disabled(self): + """ + PUT to the status endpoint with use_presence disabled will NOT call + set_state on the presence handler. + """ + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py new file mode 100644 index 0000000000..2860579c2e --- /dev/null +++ b/tests/rest/client/test_profile.py @@ -0,0 +1,270 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /profile paths.""" +from synapse.rest import admin +from synapse.rest.client import login, profile, room + +from tests import unittest + + +class ProfileTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def prepare(self, reactor, clock, hs): + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.other = self.register_user("other", "pass", displayname="Bob") + + def test_get_displayname(self): + res = self._get_displayname() + self.assertEqual(res, "owner") + + def test_set_displayname(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_displayname() + self.assertEqual(res, "test") + + def test_set_displayname_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_displayname_too_long(self): + """Attempts to set a stupid displayname should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_displayname() + self.assertEqual(res, "owner") + + def test_get_displayname_other(self): + res = self._get_displayname(self.other) + self.assertEquals(res, "Bob") + + def test_set_displayname_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.other,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_get_avatar_url(self): + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_set_avatar_url(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_avatar_url() + self.assertEqual(res, "http://my.server/pic.gif") + + def test_set_avatar_url_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_avatar_url_too_long(self): + """Attempts to set a stupid avatar_url should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_get_avatar_url_other(self): + res = self._get_avatar_url(self.other) + self.assertIsNone(res) + + def test_set_avatar_url_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.other,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def _get_displayname(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/displayname" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["displayname"] + + def _get_avatar_url(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/avatar_url" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body.get("avatar_url") + + +class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User owning the requested profile. + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.profile_url = "/profile/%s" % (self.owner) + + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) + + def test_no_auth(self): + self.try_fetch_profile(401) + + def test_not_in_shared_room(self): + self.ensure_requester_left_room() + + self.try_fetch_profile(403, access_token=self.requester_tok) + + def test_in_shared_room(self): + self.ensure_requester_left_room() + + self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) + + self.try_fetch_profile(200, self.requester_tok) + + def try_fetch_profile(self, expected_code, access_token=None): + self.request_profile(expected_code, access_token=access_token) + + self.request_profile( + expected_code, url_suffix="/displayname", access_token=access_token + ) + + self.request_profile( + expected_code, url_suffix="/avatar_url", access_token=access_token + ) + + def request_profile(self, expected_code, url_suffix="", access_token=None): + channel = self.make_request( + "GET", self.profile_url + url_suffix, access_token=access_token + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def ensure_requester_left_room(self): + try: + self.helper.leave( + room=self.room_id, user=self.requester, tok=self.requester_tok + ) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass + + +class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + def test_can_lookup_own_profile(self): + """Tests that a user can lookup their own profile without having to be in a room + if 'require_auth_for_profile_requests' is set to true in the server's config. + """ + channel = self.make_request( + "GET", "/profile/" + self.requester, access_token=self.requester_tok + ) + self.assertEqual(channel.code, 200, channel.result) + + channel = self.make_request( + "GET", + "/profile/" + self.requester + "/displayname", + access_token=self.requester_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + channel = self.make_request( + "GET", + "/profile/" + self.requester + "/avatar_url", + access_token=self.requester_tok, + ) + self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py new file mode 100644 index 0000000000..d0ce91ccd9 --- /dev/null +++ b/tests/rest/client/test_push_rule_attrs.py @@ -0,0 +1,414 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse +from synapse.api.errors import Codes +from synapse.rest.client import login, push_rule, room + +from tests.unittest import HomeserverTestCase + + +class PushRuleAttributesTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + push_rule.register_servlets, + ] + hijack_auth = False + + def test_enabled_on_creation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even though a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_on_recreation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even if a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # disable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule disabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_disable(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is disabled and enabled when we ask for it. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # disable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule disabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # re-enable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule enabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_404_when_get_non_existent(self): + """ + Tests that `enabled` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_get_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_rule(self): + """ + Tests that `enabled` gives 404 when we put to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_get(self): + """ + Tests that `actions` gives you what you expect on a fresh rule. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] + ) + + def test_actions_put(self): + """ + Tests that PUT on actions updates the value you'd get from GET. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # change the rule actions + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["actions"], ["dont_notify"]) + + def test_actions_404_when_get_non_existent(self): + """ + Tests that `actions` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_get_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_rule(self): + """ + Tests that `actions` gives 404 when putting to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py new file mode 100644 index 0000000000..fecda037a5 --- /dev/null +++ b/tests/rest/client/test_register.py @@ -0,0 +1,746 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import json +import os + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.errors import Codes +from synapse.appservice import ApplicationService +from synapse.rest.client import account, account_validity, login, logout, register, sync + +from tests import unittest +from tests.unittest import override_config + + +class RegisterRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] + url = b"/_matrix/client/r0/register" + + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config + + def test_POST_appservice_registration_valid(self): + user_id = "@as_user_kermit:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps( + {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + det_data = {"user_id": user_id, "home_server": self.hs.hostname} + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_appservice_registration_no_type(self): + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps({"username": "as_user_kermit"}) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"400", channel.result) + + def test_POST_appservice_registration_invalid(self): + self.appservice = None # no application service exists + request_data = json.dumps( + {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + def test_POST_bad_password(self): + request_data = json.dumps({"username": "kermit", "password": 666}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid password") + + def test_POST_bad_username(self): + request_data = json.dumps({"username": 777, "password": "monkey"}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid username") + + def test_POST_user_valid(self): + user_id = "@kermit:test" + device_id = "frogfone" + params = { + "username": "kermit", + "password": "monkey", + "device_id": device_id, + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + det_data = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + @override_config({"enable_registration": False}) + def test_POST_disabled_registration(self): + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) + + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Registration has been disabled") + self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + + def test_POST_guest_registration(self): + self.hs.config.macaroon_secret_key = "test" + self.hs.config.allow_guest_access = True + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_disabled_guest_registration(self): + self.hs.config.allow_guest_access = False + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Guest access is disabled") + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting_guest(self): + for i in range(0, 6): + url = self.url + b"?kind=guest" + channel = self.make_request(b"POST", url, b"{}") + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting(self): + for i in range(0, 6): + params = { + "username": "kermit" + str(i), + "password": "monkey", + "device_id": "frogfone", + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_advertised_flows(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we only expect the dummy flow + self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "enable_registration_captcha": True, + "user_consent": { + "version": "1", + "template_dir": "/", + "require_at_registration": True, + }, + "account_threepid_delegates": { + "email": "https://id_server", + "msisdn": "https://id_server", + }, + } + ) + def test_advertised_flows_captcha_and_terms_and_3pids(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + self.assertCountEqual( + [ + ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], + ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], + ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], + [ + "m.login.recaptcha", + "m.login.terms", + "m.login.msisdn", + "m.login.email.identity", + ], + ], + (f["stages"] for f in flows), + ) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "registrations_require_3pid": ["email"], + "disable_msisdn_registration": True, + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_advertised_flows_no_msisdn_email_required(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we expect all four combinations of 3pid + self.assertCountEqual( + [["m.login.email.identity"]], (f["stages"] for f in flows) + ) + + @unittest.override_config( + { + "request_token_inhibit_3pid_errors": True, + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_request_token_existing_email_inhibit_error(self): + """Test that requesting a token via this endpoint doesn't leak existing + associations if configured that way. + """ + user_id = self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.hs.get_datastore().user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(200, channel.code, channel.result) + + self.assertIsNotNone(channel.json_body.get("sid")) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_reject_invalid_email(self): + """Check that bad emails are rejected""" + + # Test for email with multiple @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + # Check error to ensure that we're not erroring due to a bug in the test. + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for email with no @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for super long email + email = "a@" + "a" * 1000 + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + +class AccountValidityTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + logout.register_servlets, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # Test for account expiring after a week. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_validity_period(self): + self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_manual_renewal(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + # If we register the admin user at the beginning of the test, it will + # expire at the same time as the normal user and the renewal request + # will be denied. + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = {"user_id": user_id} + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_manual_expire(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + channel = self.make_request(b"POST", "/logout", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + account_validity.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Test for account expiring after a week and renewal emails being sent 2 + # days before expiry. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + "renew_at": 172800000, # Time in ms for 2 days + "renew_by_email_enabled": True, + "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", + } + + # Email config. + + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "aaa" + + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + + self.store = self.hs.get_datastore() + + return self.hs + + def test_renewal_email(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) + self.assertEqual(len(self.email_attempts), 1) + + # Retrieving the URL from the email is too much pain for now, so we + # retrieve the token from the DB. + renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 3 days forward. If the renewal failed, every authed request with + # our access token should be denied from now, otherwise they should + # succeed. + self.reactor.advance(datetime.timedelta(days=3).total_seconds()) + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. + url = "/_matrix/client/unstable/account_validity/renew?token=123" + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = ( + self.hs.config.account_validity.account_validity_invalid_token_template.render() + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + def test_manual_email_send(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + def test_deactivated_user(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } + ) + channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(channel.code, 200) + + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + self.assertEqual(len(self.email_attempts), 0) + + def create_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + return user_id, tok + + def test_manual_email_send_expired_account(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + + # Make the account expire. + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + # Ignore all emails sent by the automatic background task and only focus on the + # ones sent manually. + self.email_attempts = [] + + # Test that we're still able to manually trigger a mail to be sent. + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + +class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + self.validity_period = 10 + self.max_delta = self.validity_period * 10.0 / 100.0 + + config = self.default_config() + + config["enable_registration"] = True + config["account_validity"] = {"enabled": False} + + self.hs = self.setup_test_homeserver(config=config) + + # We need to set these directly, instead of in the homeserver config dict above. + # This is due to account validity-related config options not being read by + # Synapse when account_validity.enabled is False. + self.hs.get_datastore()._account_validity_period = self.validity_period + self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + + self.store = self.hs.get_datastore() + + return self.hs + + def test_background_job(self): + """ + Tests the same thing as test_background_job, except that it sets the + startup_job_max_delta parameter and checks that the expiration date is within the + allowed range. + """ + user_id = self.register_user("kermit_delta", "user") + + self.hs.config.account_validity.startup_job_max_delta = self.max_delta + + now_ms = self.hs.get_clock().time_msec() + self.get_success(self.store._set_expiration_date_when_missing()) + + res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + + self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) + self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py new file mode 100644 index 0000000000..02b5e9a8d0 --- /dev/null +++ b/tests/rest/client/test_relations.py @@ -0,0 +1,724 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json +import urllib +from typing import Optional + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.rest import admin +from synapse.rest.client import login, register, relations, room + +from tests import unittest + + +class RelationsTestCase(unittest.HomeserverTestCase): + servlets = [ + relations.register_servlets, + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets_for_client_rest_resource, + ] + hijack_auth = False + + def make_homeserver(self, reactor, clock): + # We need to enable msc1849 support for aggregations + config = self.default_config() + config["experimental_msc1849_support_enabled"] = True + + # We enable frozen dicts as relations/edits change event contents, so we + # want to test that we don't modify the events in the caches. + config["use_frozen_dicts"] = True + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.user_id, self.user_token = self._create_user("alice") + self.user2_id, self.user2_token = self._create_user("bob") + + self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) + self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) + res = self.helper.send(self.room, body="Hi!", tok=self.user_token) + self.parent_id = res["event_id"] + + def test_send_relation(self): + """Tests that sending a relation using the new /send_relation works + creates the right shape of event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + self.assertEquals(200, channel.code, channel.json_body) + + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assert_dict( + { + "type": "m.reaction", + "sender": self.user_id, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "key": "👍", + "rel_type": RelationTypes.ANNOTATION, + } + }, + }, + channel.json_body, + ) + + def test_deny_membership(self): + """Test that we deny relations on membership events""" + channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) + self.assertEquals(400, channel.code, channel.json_body) + + def test_deny_double_react(self): + """Test that we deny relations on membership events""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(400, channel.code, channel.json_body) + + def test_basic_paginate_relations(self): + """Tests that calling pagination API correctly the latest relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEquals( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), str, channel.json_body + ) + + def test_repeated_paginate_relations(self): + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = None + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation_pagination_groups(self): + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + + idx += 1 + idx %= len(access_tokens) + + prev_token = None + found_groups = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEquals(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self): + """Test that we can paginate within an annotation group.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="👍", + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + idx += 1 + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_event_ids = [] + encoded_key = urllib.parse.quote_plus("👍".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s" + "/aggregations/%s/%s/m.reaction/%s?limit=1%s" + % ( + self.room, + self.parent_id, + RelationTypes.ANNOTATION, + encoded_key, + from_token, + ), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation(self): + """Test that annotations get correctly aggregated.""" + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + ) + + def test_aggregation_redactions(self): + """Test that annotations get correctly aggregated after a redaction.""" + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Now lets redact one of the 'a' reactions + channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), + access_token=self.user_token, + content={}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + def test_aggregation_must_be_annotation(self): + """Test that aggregations must be annotations.""" + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" + % (self.room, self.parent_id, RelationTypes.REPLACE), + access_token=self.user_token, + ) + self.assertEquals(400, channel.code, channel.json_body) + + def test_aggregation_get_event(self): + """Test that annotations and references get correctly bundled when + getting the parent event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_2 = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + }, + ) + + def test_edit(self): + """Test that a simple edit works.""" + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_multi_edit(self): + """Test that multiple edits, including attempts by people who + shouldn't be allowed, are correctly handled. + """ + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_edit_reply(self): + """Test that editing a reply works.""" + + # Create a reply to edit. + channel = self._send_relation( + RelationTypes.REFERENCE, + "m.room.message", + content={"msgtype": "m.text", "body": "A reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + reply = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=reply, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, reply), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to see the new body in the dict, as well as the reference + # metadata sill intact. + self.assertDictContainsSubset(new_body, channel.json_body["content"]) + self.assertDictContainsSubset( + { + "m.relates_to": { + "event_id": self.parent_id, + "key": None, + "rel_type": "m.reference", + } + }, + channel.json_body["content"], + ) + + # We expect that the edit relation appears in the unsigned relations + # section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_relations_redaction_redacts_edits(self): + """Test that edits of an event are redacted when the original event + is redacted. + """ + # Send a new event + res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + parent_id=original_event_id, + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check the relation is returned + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(len(channel.json_body["chunk"]), 1) + + # Redact the original event + channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), + access_token=self.user_token, + content="{}", + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Try to check for remaining m.replace relations + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that no relations are returned + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + + def test_aggregations_redaction_prevents_access_to_aggregations(self): + """Test that annotations of an event are redacted when the original event + is redacted. + """ + # Send a new event + res = self.helper.send(self.room, body="Hello!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Redact the original + channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % ( + self.room, + original_event_id, + "test_aggregations_redaction_prevents_access_to_aggregations", + ), + access_token=self.user_token, + content="{}", + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that aggregations returns zero + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + + def _send_relation( + self, + relation_type, + event_type, + key=None, + content: Optional[dict] = None, + access_token=None, + parent_id=None, + ): + """Helper function to send a relation pointing at `self.parent_id` + + Args: + relation_type (str): One of `RelationTypes` + event_type (str): The type of the event to create + parent_id (str): The event_id this relation relates to. If None, then self.parent_id + key (str|None): The aggregation key used for m.annotation relation + type. + content(dict|None): The content of the created event. + access_token (str|None): The access token used to send the relation, + defaults to `self.user_token` + + Returns: + FakeChannel + """ + if not access_token: + access_token = self.user_token + + query = "" + if key: + query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) + + original_id = parent_id if parent_id else self.parent_id + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" + % (self.room, original_id, relation_type, event_type, query), + json.dumps(content or {}).encode("utf-8"), + access_token=access_token, + ) + return channel + + def _create_user(self, localpart): + user_id = self.register_user(localpart, "abc123") + access_token = self.login(localpart, "abc123") + + return user_id, access_token diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py new file mode 100644 index 0000000000..ee6b0b9ebf --- /dev/null +++ b/tests/rest/client/test_report_event.py @@ -0,0 +1,82 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.rest.client import login, report_event, room + +from tests import unittest + + +class ReportEventTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) + resp = self.helper.send(self.room_id, tok=self.admin_user_tok) + self.event_id = resp["event_id"] + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" + + def test_reason_str_and_score_int(self): + data = {"reason": "this makes me sad", "score": -100} + self._assert_status(200, data) + + def test_no_reason(self): + data = {"score": 0} + self._assert_status(200, data) + + def test_no_score(self): + data = {"reason": "this makes me sad"} + self._assert_status(200, data) + + def test_no_reason_and_no_score(self): + data = {} + self._assert_status(200, data) + + def test_reason_int_and_score_str(self): + data = {"reason": 10, "score": "string"} + self._assert_status(400, data) + + def test_reason_zero_and_score_blank(self): + data = {"reason": 0, "score": ""} + self._assert_status(400, data) + + def test_reason_and_score_null(self): + data = {"reason": None, "score": None} + self._assert_status(400, data) + + def _assert_status(self, response_status, data): + channel = self.make_request( + "POST", + self.report_path, + json.dumps(data), + access_token=self.other_user_tok, + ) + self.assertEqual( + response_status, int(channel.result["code"]), msg=channel.result["body"] + ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py new file mode 100644 index 0000000000..0c9cbb9aff --- /dev/null +++ b/tests/rest/client/test_rooms.py @@ -0,0 +1,2150 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +import json +from typing import Iterable +from unittest.mock import Mock, call +from urllib import parse as urlparse + +from twisted.internet import defer + +import synapse.rest.admin +from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException +from synapse.handlers.pagination import PurgeStatus +from synapse.rest import admin +from synapse.rest.client import account, directory, login, profile, room +from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.test_utils import make_awaitable + +PATH_PREFIX = b"/_matrix/client/api/v1" + + +class RoomBase(unittest.HomeserverTestCase): + rmcreator_id = None + + servlets = [room.register_servlets, room.register_deprecated_servlets] + + def make_homeserver(self, reactor, clock): + + self.hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + ) + + self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler.return_value.maybe_backfill = Mock( + return_value=make_awaitable(None) + ) + + async def _insert_client_ip(*args, **kwargs): + return None + + self.hs.get_datastore().insert_client_ip = _insert_client_ip + + return self.hs + + +class RoomPermissionsTestCase(RoomBase): + """Tests room permissions.""" + + user_id = "@sid1:red" + rmcreator_id = "@notme:red" + + def prepare(self, reactor, clock, hs): + + self.helper.auth_user_id = self.rmcreator_id + # create some rooms under the name rmcreator_id + self.uncreated_rmid = "!aa:test" + self.created_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=False + ) + self.created_public_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=True + ) + + # send a message in one of the rooms + self.created_rmid_msg_path = ( + "rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ).encode("ascii") + channel = self.make_request( + "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' + ) + self.assertEquals(200, channel.code, channel.result) + + # set topic for public room + channel = self.make_request( + "PUT", + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), + b'{"topic":"Public Room Topic"}', + ) + self.assertEquals(200, channel.code, channel.result) + + # auth as user_id now + self.helper.auth_user_id = self.user_id + + def test_can_do_action(self): + msg_content = b'{"msgtype":"m.text","body":"hello"}' + + seq = iter(range(100)) + + def send_msg_path(): + return "/rooms/%s/send/m.room.message/mid%s" % ( + self.created_rmid, + str(next(seq)), + ) + + # send message in uncreated room, expect 403 + channel = self.make_request( + "PUT", + "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + msg_content, + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room not joined (no state), expect 403 + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room and invited, expect 403 + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room and joined, expect 200 + self.helper.join(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # send message in created room and left, expect 403 + self.helper.leave(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_topic_perms(self): + topic_content = b'{"topic":"My Topic Name"}' + topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid + + # set/get topic in uncreated room, expect 403 + channel = self.make_request( + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set/get topic in created PRIVATE room not joined, expect 403 + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request("GET", topic_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set topic in created PRIVATE room and invited, expect 403 + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # get topic in created PRIVATE room and invited, expect 403 + channel = self.make_request("GET", topic_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set/get topic in created PRIVATE room and joined, expect 200 + self.helper.join(room=self.created_rmid, user=self.user_id) + + # Only room ops can set topic by default + self.helper.auth_user_id = self.rmcreator_id + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.helper.auth_user_id = self.user_id + + channel = self.make_request("GET", topic_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) + + # set/get topic in created PRIVATE room and left, expect 403 + self.helper.leave(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request("GET", topic_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # get topic in PUBLIC room, not joined, expect 403 + channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set topic in PUBLIC room, not joined, expect 403 + channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content, + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def _test_get_membership( + self, room=None, members: Iterable = frozenset(), expect_code=None + ): + for member in members: + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + channel = self.make_request("GET", path) + self.assertEquals(expect_code, channel.code) + + def test_membership_basic_room_perms(self): + # === room does not exist === + room = self.uncreated_rmid + # get membership of self, get membership of other, uncreated room + # expect all 403s + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # trying to invite people to this room should 403 + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + + # set [invite/join/left] of self, set [invite/join/left] of other, + # expect all 404s because room doesn't exist on any server + for usr in [self.user_id, self.rmcreator_id]: + self.helper.join(room=room, user=usr, expect_code=404) + self.helper.leave(room=room, user=usr, expect_code=404) + + def test_membership_private_room_perms(self): + room = self.created_rmid + # get membership of self, get membership of other, private room + invite + # expect all 403s + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # get membership of self, get membership of other, private room + joined + # expect all 200s + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + # get membership of self, get membership of other, private room + left + # expect all 200s + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + def test_membership_public_room_perms(self): + room = self.created_public_rmid + # get membership of self, get membership of other, public room + invite + # expect 403 + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # get membership of self, get membership of other, public room + joined + # expect all 200s + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + # get membership of self, get membership of other, public room + left + # expect 200. + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + def test_invited_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + + # set [invite/join/left] of other user, expect 403s + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.JOIN, + expect_code=403, + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + + def test_joined_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + + # set invited of self, expect 403 + self.helper.invite( + room=room, src=self.user_id, targ=self.user_id, expect_code=403 + ) + + # set joined of self, expect 200 (NOOP) + self.helper.join(room=room, user=self.user_id) + + other = "@burgundy:red" + # set invited of other, expect 200 + self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) + + # set joined of other, expect 403 + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.JOIN, + expect_code=403, + ) + + # set left of other, expect 403 + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.LEAVE, + expect_code=403, + ) + + # set left of self, expect 200 + self.helper.leave(room=room, user=self.user_id) + + def test_leave_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) + + # set [invite/join/left] of self, set [invite/join/left] of other, + # expect all 403s + for usr in [self.user_id, self.rmcreator_id]: + self.helper.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.INVITE, + expect_code=403, + ) + + self.helper.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.JOIN, + expect_code=403, + ) + + # It is always valid to LEAVE if you've already left (currently.) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + + +class RoomsMemberListTestCase(RoomBase): + """Tests /rooms/$room_id/members/list REST events.""" + + user_id = "@sid1:red" + + def test_get_member_list(self): + room_id = self.helper.create_room_as(self.user_id) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_room(self): + channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_permission(self): + room_id = self.helper.create_room_as("@some_other_guy:red") + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_mixed_memberships(self): + room_creator = "@some_other_guy:red" + room_id = self.helper.create_room_as(room_creator) + room_path = "/rooms/%s/members" % room_id + self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) + # can't see list if you're just invited. + channel = self.make_request("GET", room_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + self.helper.join(room=room_id, user=self.user_id) + # can see list now joined + channel = self.make_request("GET", room_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + self.helper.leave(room=room_id, user=self.user_id) + # can see old list once left + channel = self.make_request("GET", room_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + +class RoomsCreateTestCase(RoomBase): + """Tests /rooms and /rooms/$room_id REST events.""" + + user_id = "@sid1:red" + + def test_post_room_no_keys(self): + # POST with no config keys, expect new room id + channel = self.make_request("POST", "/createRoom", "{}") + + self.assertEquals(200, channel.code, channel.result) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_visibility_key(self): + # POST with visibility config key, expect new room id + channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_custom_key(self): + # POST with custom config keys, expect new room id + channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_known_and_unknown_keys(self): + # POST with custom + known config keys, expect new room id + channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' + ) + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_invalid_content(self): + # POST with invalid content / paths, expect 400 + channel = self.make_request("POST", "/createRoom", b'{"visibili') + self.assertEquals(400, channel.code) + + channel = self.make_request("POST", "/createRoom", b'["hello"]') + self.assertEquals(400, channel.code) + + def test_post_room_invitees_invalid_mxid(self): + # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 + # Note the trailing space in the MXID here! + channel = self.make_request( + "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' + ) + self.assertEquals(400, channel.code) + + @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) + def test_post_room_invitees_ratelimit(self): + """Test that invites sent when creating a room are ratelimited by a RateLimiter, + which ratelimits them correctly, including by not limiting when the requester is + exempt from ratelimiting. + """ + + # Build the request's content. We use local MXIDs because invites over federation + # are more difficult to mock. + content = json.dumps( + { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } + ).encode("utf8") + + # Test that the invites are correctly ratelimited. + channel = self.make_request("POST", "/createRoom", content) + self.assertEqual(400, channel.code) + self.assertEqual( + "Cannot invite so many users at once", + channel.json_body["error"], + ) + + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + ) + + # Test that the invites aren't ratelimited anymore. + channel = self.make_request("POST", "/createRoom", content) + self.assertEqual(200, channel.code) + + +class RoomTopicTestCase(RoomBase): + """Tests /rooms/$room_id/topic REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + # create the room + self.room_id = self.helper.create_room_as(self.user_id) + self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) + + def test_invalid_puts(self): + # missing keys or invalid json + channel = self.make_request("PUT", self.path, "{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, '{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request( + "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' + ) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, "text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, "") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # valid key, wrong type + content = '{"topic":["Topic name"]}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_topic(self): + # nothing should be there + channel = self.make_request("GET", self.path) + self.assertEquals(404, channel.code, msg=channel.result["body"]) + + # valid put + content = '{"topic":"Topic name"}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # valid get + channel = self.make_request("GET", self.path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) + + def test_rooms_topic_with_extra_keys(self): + # valid put with extra keys + content = '{"topic":"Seasons","subtopic":"Summer"}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # valid get + channel = self.make_request("GET", self.path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) + + +class RoomMemberStateTestCase(RoomBase): + """Tests /rooms/$room_id/members/$user_id/state REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_invalid_puts(self): + path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) + # missing keys or invalid json + channel = self.make_request("PUT", path, "{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, '{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, '{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, "text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, "") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # valid keys, wrong types + content = '{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + ) + channel = self.make_request("PUT", path, content.encode("ascii")) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_members_self(self): + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.user_id, + ) + + # valid join message (NOOP since we made the room) + content = '{"membership":"%s"}' % Membership.JOIN + channel = self.make_request("PUT", path, content.encode("ascii")) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + expected_response = {"membership": Membership.JOIN} + self.assertEquals(expected_response, channel.json_body) + + def test_rooms_members_other(self): + self.other_id = "@zzsid1:red" + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.other_id, + ) + + # valid invite message + content = '{"membership":"%s"}' % Membership.INVITE + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) + + def test_rooms_members_other_custom_keys(self): + self.other_id = "@zzsid1:red" + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.other_id, + ) + + # valid invite message with custom key + content = '{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, + "Join us!", + ) + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) + + +class RoomInviteRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_rooms_ratelimit(self): + """Tests that invites in a room are actually rate-limited.""" + room_id = self.helper.create_room_as(self.user_id) + + for i in range(3): + self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) + + self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_users_ratelimit(self): + """Tests that invites to a specific user are actually rate-limited.""" + + for _ in range(3): + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red") + + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) + + +class RoomJoinRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit(self): + """Tests that local joins are actually rate-limited.""" + for _ in range(3): + self.helper.create_room_as(self.user_id) + + self.helper.create_room_as(self.user_id, expect_code=429) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_profile_change(self): + """Tests that sending a profile update into all of the user's joined rooms isn't + rate-limited by the rate-limiter on joins.""" + + # Create and join as many rooms as the rate-limiting config allows in a second. + room_ids = [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + # Let some time for the rate-limiter to forget about our multi-join. + self.reactor.advance(2) + # Add one to make sure we're joined to more rooms than the config allows us to + # join in a second. + room_ids.append(self.helper.create_room_as(self.user_id)) + + # Create a profile for the user, since it hasn't been done on registration. + store = self.hs.get_datastore() + self.get_success( + store.create_profile(UserID.from_string(self.user_id).localpart) + ) + + # Update the display name for the user. + path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id + channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + self.assertEquals(channel.code, 200, channel.json_body) + + # Check that all the rooms have been sent a profile update into. + for room_id in room_ids: + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + room_id, + self.user_id, + ) + + channel = self.make_request("GET", path) + self.assertEquals(channel.code, 200) + + self.assertIn("displayname", channel.json_body) + self.assertEquals(channel.json_body["displayname"], "John Doe") + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_idempotent(self): + """Tests that the room join endpoints remain idempotent despite rate-limiting + on room joins.""" + room_id = self.helper.create_room_as(self.user_id) + + # Let's test both paths to be sure. + paths_to_test = [ + "/_matrix/client/r0/rooms/%s/join", + "/_matrix/client/r0/join/%s", + ] + + for path in paths_to_test: + # Make sure we send more requests than the rate-limiting config would allow + # if all of these requests ended up joining the user to a room. + for _ in range(4): + channel = self.make_request("POST", path % room_id, {}) + self.assertEquals(channel.code, 200) + + @unittest.override_config( + { + "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, + "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], + "autocreate_auto_join_rooms": True, + }, + ) + def test_autojoin_rooms(self): + user_id = self.register_user("testuser", "password") + + # Check that the new user successfully joined the four rooms + rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 4) + + +class RoomMessagesTestCase(RoomBase): + """Tests /rooms/$room_id/messages/$user_id/$msg_id REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_invalid_puts(self): + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) + # missing keys or invalid json + channel = self.make_request("PUT", path, b"{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b"text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b"") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_messages_sent(self): + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) + + content = b'{"body":"test","msgtype":{"type":"a"}}' + channel = self.make_request("PUT", path, content) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # custom message types + content = b'{"body":"test","msgtype":"test.custom.text"}' + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # m.text message type + path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) + content = b'{"body":"test2","msgtype":"m.text"}' + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + +class RoomInitialSyncTestCase(RoomBase): + """Tests /rooms/$room_id/initialSync.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + # create the room + self.room_id = self.helper.create_room_as(self.user_id) + + def test_initial_sync(self): + channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) + self.assertEquals(200, channel.code) + + self.assertEquals(self.room_id, channel.json_body["room_id"]) + self.assertEquals("join", channel.json_body["membership"]) + + # Room state is easier to assert on if we unpack it into a dict + state = {} + for event in channel.json_body["state"]: + if "state_key" not in event: + continue + t = event["type"] + if t not in state: + state[t] = [] + state[t].append(event) + + self.assertTrue("m.room.create" in state) + + self.assertTrue("messages" in channel.json_body) + self.assertTrue("chunk" in channel.json_body["messages"]) + self.assertTrue("end" in channel.json_body["messages"]) + + self.assertTrue("presence" in channel.json_body) + + presence_by_user = { + e["content"]["user_id"]: e for e in channel.json_body["presence"] + } + self.assertTrue(self.user_id in presence_by_user) + self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) + + +class RoomMessageListTestCase(RoomBase): + """Tests /rooms/$room_id/messages REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_topo_token_is_accepted(self): + token = "t1-0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + self.assertEquals(200, channel.code) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body["start"]) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) + + def test_stream_token_is_accepted_for_fwd_pagianation(self): + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + self.assertEquals(200, channel.code) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body["start"]) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) + + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + first_token_str = self.get_success(first_token.to_string(store)) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + second_token_str = self.get_success(second_token.to_string(store)) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token_str, + delete_local_events=True, + ) + ) + + # Check that we only get the second message through /message now that the first + # has been purged. + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + first_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + + +class RoomSearchTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def prepare(self, reactor, clock, hs): + + # Register the user who does the searching + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register the user who sends the message + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + # Invite the other person + self.helper.invite( + room=self.room, + src=self.user_id, + tok=self.access_token, + targ=self.other_user_id, + ) + + # The other user joins + self.helper.join( + room=self.room, user=self.other_user_id, tok=self.other_access_token + ) + + def test_finds_message(self): + """ + The search functionality will search for content in messages if asked to + do so. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": {"keys": ["content.body"], "search_term": "Hi"} + } + }, + ) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # No context was requested, so we should get none. + self.assertEqual(results["results"][0]["context"], {}) + + def test_include_context(self): + """ + When event_context includes include_profile, profile information will be + included in the search response. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": { + "keys": ["content.body"], + "search_term": "Hi", + "event_context": {"include_profile": True}, + } + } + }, + ) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # We should get context info, like the two users, and the display names. + context = results["results"][0]["context"] + self.assertEqual(len(context["profile_info"].keys()), 2) + self.assertEqual( + context["profile_info"][self.other_user_id]["displayname"], "otheruser" + ) + + +class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/publicRooms" + + config = self.default_config() + config["allow_public_rooms_without_auth"] = False + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_restricted_no_auth(self): + channel = self.make_request("GET", self.url) + self.assertEqual(channel.code, 401, channel.result) + + def test_restricted_auth(self): + self.register_user("user", "pass") + tok = self.login("user", "pass") + + channel = self.make_request("GET", self.url, access_token=tok) + self.assertEqual(channel.code, 200, channel.result) + + +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fallback to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. + self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + +class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["allow_per_room_profiles"] = False + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + + # Set a profile for the test user + self.displayname = "test user" + data = {"displayname": self.displayname} + request_data = json.dumps(data) + channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), + request_data, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_per_room_profile_forbidden(self): + data = {"membership": "join", "displayname": "other test user"} + request_data = json.dumps(data) + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (self.room_id, self.user_id), + request_data, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res_displayname = channel.json_body["content"]["displayname"] + self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/join", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/leave", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/kick", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/ban", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/unban", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/invite", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/leave", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_context_filter_labels(self): + """Test that we can filter by a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = channel.json_body["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ + event_id = self._send_labelled_messages_in_room() + + channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 2, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 4, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 1, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + Returns: + The ID of the event to use if we're testing filtering on /context. + """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id + + +class ContextTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + account.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.room_id = self.helper.create_room_as( + self.user_id, tok=self.tok, is_public=False + ) + + self.other_user_id = self.register_user("user2", "password") + self.other_tok = self.login("user2", "password") + + self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) + + def test_erased_sender(self): + """Test that an erasure request results in the requester's events being hidden + from any new member of the room. + """ + + # Send a bunch of events in the room. + + self.helper.send(self.room_id, "message 1", tok=self.tok) + self.helper.send(self.room_id, "message 2", tok=self.tok) + event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] + self.helper.send(self.room_id, "message 4", tok=self.tok) + self.helper.send(self.room_id, "message 5", tok=self.tok) + + # Check that we can still see the messages before the erasure request. + + channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertEqual( + events_before[0].get("content", {}).get("body"), + "message 2", + events_before[0], + ) + self.assertEqual( + events_before[1].get("content", {}).get("body"), + "message 1", + events_before[1], + ) + + self.assertEqual( + channel.json_body["event"].get("content", {}).get("body"), + "message 3", + channel.json_body["event"], + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertEqual( + events_after[0].get("content", {}).get("body"), + "message 4", + events_after[0], + ) + self.assertEqual( + events_after[1].get("content", {}).get("body"), + "message 5", + events_after[1], + ) + + # Deactivate the first account and erase the user's data. + + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) + ) + + # Invite another user in the room. This is needed because messages will be + # pruned only if the user wasn't a member of the room when the messages were + # sent. + + invited_user_id = self.register_user("user3", "password") + invited_tok = self.login("user3", "password") + + self.helper.invite( + self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok + ) + self.helper.join(self.room_id, invited_user_id, tok=invited_tok) + + # Check that a user that joined the room after the erasure request can't see + # the messages anymore. + + channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=invited_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) + self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) + + self.assertDictEqual( + channel.json_body["event"].get("content"), {}, channel.json_body["event"] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) + self.assertEqual(events_after[1].get("content"), {}, events_after[1]) + + +class RoomAliasListTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + def test_no_aliases(self): + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(res["aliases"], []) + + def test_not_in_room(self): + self.register_user("user", "test") + user_tok = self.login("user", "test") + res = self._get_aliases(user_tok, expected_code=403) + self.assertEqual(res["errcode"], "M_FORBIDDEN") + + def test_admin_user(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.register_user("user", "test", admin=True) + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def test_with_aliases(self): + alias1 = self._random_alias() + alias2 = self._random_alias() + + self._set_alias_via_directory(alias1) + self._set_alias_via_directory(alias2) + + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(set(res["aliases"]), {alias1, alias2}) + + def test_peekable_room(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "world_readable"}, + tok=self.room_owner_tok, + ) + + self.register_user("user", "test") + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,), + access_token=access_token, + ) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + if expected_code == 200: + self.assertIsInstance(res["aliases"], list) + return res + + def _random_alias(self) -> str: + return RoomAlias(random_string(5), self.hs.hostname).to_string() + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py new file mode 100644 index 0000000000..6db7062a8e --- /dev/null +++ b/tests/rest/client/test_sendtodevice.py @@ -0,0 +1,200 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client import login, sendtodevice, sync + +from tests.unittest import HomeserverTestCase, override_config + + +class SendToDeviceTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + sendtodevice.register_servlets, + sync.register_servlets, + ] + + def test_user_to_user(self): + """A to-device message from one user to another should get delivered""" + + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send the message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # check it appears + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + expected_result = { + "events": [ + { + "sender": user1, + "type": "m.test", + "content": test_msg, + } + ] + } + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should re-appear if we do another sync + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should *not* appear if we do an incremental sync + sync_token = channel.json_body["next_batch"] + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_local_room_key_request(self): + """m.room_key_request has special-casing; test from local user""" + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send three messages + for i in range(3): + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", + content={"messages": {user2: {"d2": {"idx": i}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # now sync: we should get two of the three + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={"messages": {user2: {"d2": {"idx": 3}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # ... which should arrive + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, + ) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_remote_room_key_request(self): + """m.room_key_request has special-casing; test from remote user""" + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + federation_registry = self.hs.get_federation_registry() + + # send three messages + for i in range(3): + self.get_success( + federation_registry.on_edu( + "m.direct_to_device", + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": i}}}, + "message_id": f"{i}", + }, + ) + ) + + # now sync: we should get two of the three + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": i}, + }, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + self.get_success( + federation_registry.on_edu( + "m.direct_to_device", + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": 3}}}, + "message_id": "3", + }, + ) + ) + + # ... which should arrive + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": 3}, + }, + ) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py new file mode 100644 index 0000000000..283eccd53f --- /dev/null +++ b/tests/rest/client/test_shared_rooms.py @@ -0,0 +1,142 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.rest.client import login, room, shared_rooms + +from tests import unittest +from tests.server import FakeChannel + + +class UserSharedRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserSharedRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + shared_rooms.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def _get_shared_rooms(self, token, other_user) -> FakeChannel: + return self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" + % other_user, + access_token=token, + ) + + def test_shared_room_list_public(self): + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) + + def test_shared_room_list_private(self): + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + self._check_shared_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) + + def test_shared_room_list_mixed(self): + """ + The shared room list between two users should contain both public and private + rooms. + """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_shared_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ): + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 2) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) + + def test_shared_room_list_after_leave(self): + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + # Check user1's view of shared rooms with user2 + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) + + # Check user2's view of shared rooms with user1 + channel = self._get_shared_rooms(u2_token, u1) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py new file mode 100644 index 0000000000..95be369d4b --- /dev/null +++ b/tests/rest/client/test_sync.py @@ -0,0 +1,686 @@ +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +import synapse.rest.admin +from synapse.api.constants import ( + EventContentFields, + EventTypes, + ReadReceiptEventFields, + RelationTypes, +) +from synapse.rest.client import knock, login, read_marker, receipts, room, sync + +from tests import unittest +from tests.federation.transport.test_knocking import ( + KnockingStrippedStateEventHelperMixin, +) +from tests.server import TimedOutException +from tests.unittest import override_config + + +class FilterTestCase(unittest.HomeserverTestCase): + + user_id = "@apple:test" + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_argless(self): + channel = self.make_request("GET", "/sync") + + self.assertEqual(channel.code, 200) + self.assertIn("next_batch", channel.json_body) + + +class SyncFilterTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_filter_labels(self): + """Test that we can filter by a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def _test_sync_filter_labels(self, sync_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + channel = self.make_request( + "GET", "/sync?filter=%s" % sync_filter, access_token=tok + ) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + + +class SyncTypingTests(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + user_id = True + hijack_auth = False + + def test_sync_backwards_typing(self): + """ + If the typing serial goes backwards and the typing handler is then reset + (such as when the master restarts and sets the typing serial to 0), we + do not incorrectly return typing information that had a serial greater + than the now-reset serial. + """ + typing_url = "/rooms/%s/typing/%s?access_token=%s" + sync_url = "/sync?timeout=3000000&access_token=%s&since=%s" + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Start typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Stop typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": false}', + ) + self.assertEquals(200, channel.code) + + # Start typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + # Should return immediately + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Reset typing serial back to 0, as if the master had. + typing = self.hs.get_typing_handler() + typing._latest_room_serial = 0 + + # Since it checks the state token, we need some state to update to + # invalidate the stream token. + self.helper.send(room, body="There!", tok=other_access_token) + + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # This should time out! But it does not, because our stream token is + # ahead, and therefore it's saying the typing (that we've actually + # already seen) is new, since it's got a token above our new, now-reset + # stream token. + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Clear the typing information, so that it doesn't think everything is + # in the future. + typing._reset() + + # Now it SHOULD fail as it never completes! + with self.assertRaises(TimedOutException): + self.make_request("GET", sync_url % (access_token, next_batch)) + + +class SyncKnockTestCase( + unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin +): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + knock.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to create the room to knock on). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll knock on. + self.room_id = self.helper.create_room_as( + self.user_id, + is_public=False, + room_version="7", + tok=self.tok, + ) + + # Register the second user (used to knock on the room). + self.knocker = self.register_user("knocker", "monkey") + self.knocker_tok = self.login("knocker", "monkey") + + # Perform an initial sync for the knocking user. + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + # Set up some room state to test with. + self.expected_room_state = self.send_example_state_events_to_room( + hs, self.room_id, self.user_id + ) + + @override_config({"experimental_features": {"msc2403_enabled": True}}) + def test_knock_room_state(self): + """Tests that /sync returns state from a room after knocking on it.""" + # Knock on a room + channel = self.make_request( + "POST", + "/_matrix/client/r0/knock/%s" % (self.room_id,), + b"{}", + self.knocker_tok, + ) + self.assertEquals(200, channel.code, channel.result) + + # We expect to see the knock event in the stripped room state later + self.expected_room_state[EventTypes.Member] = { + "content": {"membership": "knock", "displayname": "knocker"}, + "state_key": "@knocker:test", + } + + # Check that /sync includes stripped state from the room + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.knocker_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Extract the stripped room state events from /sync + knock_entry = channel.json_body["rooms"]["knock"] + room_state_events = knock_entry[self.room_id]["knock_state"]["events"] + + # Validate that the knock membership event came last + self.assertEqual(room_state_events[-1]["type"], EventTypes.Member) + + # Validate the stripped room state events + self.check_knock_room_state_against_room_state( + room_state_events, self.expected_room_state + ) + + +class ReadReceiptsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + receipts.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Join the second user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_hidden_read_receipts(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt to tell the server the first user's message was read + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that the first user can't see the other user's hidden read receipt + self.assertEqual(self._get_read_receipt(), None) + + def test_read_receipt_with_empty_body(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + def _get_read_receipt(self): + """Syncs and returns the read receipt.""" + + # Checks if event is a read receipt + def is_read_receipt(event): + return event["type"] == "m.receipt" + + # Sync + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + # Return the read receipt + ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ + "ephemeral" + ]["events"] + return next(filter(is_read_receipt, ephemeral_events), None) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that hidden read receipts don't break unread counts + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, + "m.room.name", + {"name": "my super room"}, + tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, + "m.room.topic", + {"topic": "welcome!!!"}, + tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, + "org.matrix.custom_type", + {"body": "hello"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: int): + """Syncs and compares the unread count with the expected value.""" + + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], + expected_count, + room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + +class SyncCacheTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_noop_sync_does_not_tightloop(self): + """If the sync times out, we shouldn't cache the result + + Essentially a regression test for #8518. + """ + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # we should immediately get an initial sync response + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.json_body) + + # now, make an incremental sync request, with a timeout + next_batch = channel.json_body["next_batch"] + channel = self.make_request( + "GET", + f"/sync?since={next_batch}&timeout=10000", + access_token=self.tok, + await_result=False, + ) + # that should block for 10 seconds + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=9900) + channel.await_result(timeout_ms=200) + self.assertEqual(channel.code, 200, channel.json_body) + + # we expect the next_batch in the result to be the same as before + self.assertEqual(channel.json_body["next_batch"], next_batch) + + # another incremental sync should also block. + channel = self.make_request( + "GET", + f"/sync?since={next_batch}&timeout=10000", + access_token=self.tok, + await_result=False, + ) + # that should block for 10 seconds + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=9900) + channel.await_result(timeout_ms=200) + self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py new file mode 100644 index 0000000000..b54b004733 --- /dev/null +++ b/tests/rest/client/test_typing.py @@ -0,0 +1,121 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +from unittest.mock import Mock + +from synapse.rest.client import room +from synapse.types import UserID + +from tests import unittest + +PATH_PREFIX = "/_matrix/client/api/v1" + + +class RoomTypingTestCase(unittest.HomeserverTestCase): + """Tests /rooms/$room_id/typing/$user_id REST API.""" + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + ) + + self.event_source = hs.get_event_sources().sources["typing"] + + hs.get_federation_handler = Mock() + + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + hs.get_auth().get_user_by_access_token = get_user_by_access_token + + async def _insert_client_ip(*args, **kwargs): + return None + + hs.get_datastore().insert_client_ip = _insert_client_ip + + return hs + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + # Need another user to make notifications actually work + self.helper.join(self.room_id, user="@jim:red") + + def test_set_typing(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 1) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.user_id]}, + } + ], + ) + + def test_set_not_typing(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": false}', + ) + self.assertEquals(200, channel.code) + + def test_typing_timeout(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 1) + + self.reactor.advance(36) + + self.assertEquals(self.event_source.get_current_key(), 2) + + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py new file mode 100644 index 0000000000..72f976d8e2 --- /dev/null +++ b/tests/rest/client/test_upgrade_room.py @@ -0,0 +1,159 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.rest import admin +from synapse.rest.client import login, room, room_upgrade_rest_servlet + +from tests import unittest +from tests.server import FakeChannel + + +class UpgradeRoomTest(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + self.creator = self.register_user("creator", "pass") + self.creator_token = self.login(self.creator, "pass") + + self.other = self.register_user("user", "pass") + self.other_token = self.login(self.other, "pass") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) + self.helper.join(self.room_id, self.other, tok=self.other_token) + + def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: + # We never want a cached response. + self.reactor.advance(5 * 60 + 1) + + return self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, + # This will upgrade a room to the same version, but that's fine. + content={"new_version": DEFAULT_ROOM_VERSION}, + access_token=token or self.creator_token, + ) + + def test_upgrade(self): + """ + Upgrading a room should work fine. + """ + channel = self._upgrade_room() + self.assertEquals(200, channel.code, channel.result) + self.assertIn("replacement_room", channel.json_body) + + def test_not_in_room(self): + """ + Upgrading a room should work fine. + """ + # THe user isn't in the room. + roomless = self.register_user("roomless", "pass") + roomless_token = self.login(roomless, "pass") + + channel = self._upgrade_room(roomless_token) + self.assertEquals(403, channel.code, channel.result) + + def test_power_levels(self): + """ + Another user can upgrade the room if their power level is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users"][self.other] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_user_default(self): + """ + Another user can upgrade the room if the default power level for users is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users_default"] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_tombstone(self): + """ + Another user can upgrade the room if they can send the tombstone event. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["events"]["m.room.tombstone"] = 0 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py new file mode 100644 index 0000000000..954ad1a1fd --- /dev/null +++ b/tests/rest/client/utils.py @@ -0,0 +1,654 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +import time +import urllib.parse +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union +from unittest.mock import patch + +import attr + +from twisted.web.resource import Resource +from twisted.web.server import Site + +from synapse.api.constants import Membership +from synapse.types import JsonDict + +from tests.server import FakeChannel, FakeSite, make_request +from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser + + +@attr.s +class RestHelper: + """Contains extra helper functions to quickly and clearly perform a given + REST action, which isn't the focus of the test. + """ + + hs = attr.ib() + site = attr.ib(type=Site) + auth_user_id = attr.ib() + + def create_room_as( + self, + room_creator: Optional[str] = None, + is_public: bool = True, + room_version: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = 200, + extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ) -> str: + """ + Create a room. + + Args: + room_creator: The user ID to create the room with. + is_public: If True, the `visibility` parameter will be set to the + default (public). Otherwise, the `visibility` parameter will be set + to "private". + room_version: The room version to create the room as. Defaults to Synapse's + default room version. + tok: The access token to use in the request. + expect_code: The expected HTTP response code. + + Returns: + The ID of the newly created room. + """ + temp_id = self.auth_user_id + self.auth_user_id = room_creator + path = "/_matrix/client/r0/createRoom" + content = extra_content or {} + if not is_public: + content["visibility"] = "private" + if room_version: + content["room_version"] = room_version + if tok: + path = path + "?access_token=%s" % tok + + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + path, + json.dumps(content).encode("utf8"), + custom_headers=custom_headers, + ) + + assert channel.result["code"] == b"%d" % expect_code, channel.result + self.auth_user_id = temp_id + + if expect_code == 200: + return channel.json_body["room_id"] + + def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=src, + targ=targ, + tok=tok, + membership=Membership.INVITE, + expect_code=expect_code, + ) + + def join(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.JOIN, + expect_code=expect_code, + ) + + def leave(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.LEAVE, + expect_code=expect_code, + ) + + def change_membership( + self, + room: str, + src: str, + targ: str, + membership: str, + extra_data: Optional[dict] = None, + tok: Optional[str] = None, + expect_code: int = 200, + ) -> None: + """ + Send a membership state event into a room. + + Args: + room: The ID of the room to send to + src: The mxid of the event sender + targ: The mxid of the event's target. The state key + membership: The type of membership event + extra_data: Extra information to include in the content of the event + tok: The user access token to use + expect_code: The expected HTTP response code + """ + temp_id = self.auth_user_id + self.auth_user_id = src + + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + if tok: + path = path + "?access_token=%s" % tok + + data = {"membership": membership} + data.update(extra_data or {}) + + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(data).encode("utf8"), + ) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + self.auth_user_id = temp_id + + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): + if body is None: + body = "body_text_here" + + content = {"msgtype": "m.text", "body": body} + + return self.send_event( + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, + ) + + def send_event( + self, + room_id, + type, + content: Optional[dict] = None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, + ) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def _read_write_state( + self, + room_id: str, + event_type: str, + body: Optional[Dict[str, Any]], + tok: str, + expect_code: int = 200, + state_key: str = "", + method: str = "GET", + ) -> Dict: + """Read or write some state from a given room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + If None, the request to the server will have an empty body + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + method: "GET" or "PUT" for reading or writing state, respectively + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) + if tok: + path = path + "?access_token=%s" % tok + + # Set request body if provided + content = b"" + if body is not None: + content = json.dumps(body).encode("utf8") + + channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def get_state( + self, + room_id: str, + event_type: str, + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Gets some state from a room + + Args: + room_id: + event_type: The type of state event + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, None, tok, expect_code, state_key, method="GET" + ) + + def send_state( + self, + room_id: str, + event_type: str, + body: Dict[str, Any], + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Set some state in a room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, body, tok, expect_code, state_key, method="PUT" + ) + + def upload_media( + self, + resource: Resource, + image_data: bytes, + tok: str, + filename: str = "test.png", + expect_code: int = 200, + ) -> dict: + """Upload a piece of test media to the media repo + Args: + resource: The resource that will handle the upload request + image_data: The image data to upload + tok: The user token to use during the upload + filename: The filename of the media to be uploaded + expect_code: The return code to expect from attempting to upload the media + """ + image_length = len(image_data) + path = "/_matrix/media/r0/upload?filename=%s" % (filename,) + channel = make_request( + self.hs.get_reactor(), + FakeSite(resource), + "POST", + path, + content=image_data, + access_token=tok, + custom_headers=[(b"Content-Length", str(image_length))], + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def login_via_oidc(self, remote_user_id: str) -> JsonDict: + """Log in (as a new user) via OIDC + + Returns the result of the final token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + client_redirect_url = "https://x" + channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + + # expect a confirmation page + assert channel.code == 200, channel.result + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + assert channel.code == 200 + return channel.json_body + + def auth_via_oidc( + self, + user_info_dict: JsonDict, + client_redirect_url: Optional[str] = None, + ui_auth_session_id: Optional[str] = None, + ) -> FakeChannel: + """Perform an OIDC authentication flow via a mock OIDC provider. + + This can be used for either login or user-interactive auth. + + Starts by making a request to the relevant synapse redirect endpoint, which is + expected to serve a 302 to the OIDC provider. We then make a request to the + OIDC callback endpoint, intercepting the HTTP requests that will get sent back + to the OIDC provider. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + + Args: + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. + client_redirect_url: for a login flow, the client redirect URL to pass to + the login redirect endpoint + ui_auth_session_id: if set, we will perform a UI Auth flow. The session id + of the UI auth. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + Note that the response code may be a 200, 302 or 400 depending on how things + went. + """ + + cookies = {} + + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + + # we now have a URI for the OIDC IdP, but we skip that and go straight + # back to synapse's OIDC callback resource. However, we do need the "state" + # param that synapse passes to the IdP via query params, as well as the cookie + # that synapse passes to the client. + + oauth_uri_path, _ = oauth_uri.split("?", 1) + assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + "unexpected SSO URI " + oauth_uri_path + ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, + oauth_uri: str, + cookies: Mapping[str, str], + user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) + params = urllib.parse.parse_qs(oauth_uri_qs) + callback_uri = "%s?%s" % ( + urllib.parse.urlparse(params["redirect_uri"][0]).path, + urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + ) + + # before we hit the callback uri, stub out some methods in the http client so + # that we don't have to handle full HTTPS requests. + # (expected url, json response) pairs, in the order we expect them. + expected_requests = [ + # first we get a hit to the token endpoint, which we tell to return + # a dummy OIDC access token + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), + # and then one to the user_info endpoint, which returns our remote user id. + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), + ] + + async def mock_req(method: str, uri: str, data=None, headers=None): + (expected_uri, resp_obj) = expected_requests.pop(0) + assert uri == expected_uri + resp = FakeResponse( + code=200, + phrase=b"OK", + body=json.dumps(resp_obj).encode("utf-8"), + ) + return resp + + with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + # now hit the callback URI with the right params and a made-up code + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + callback_uri, + custom_headers=[ + ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() + ], + ) + return channel + + def initiate_sso_login( + self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the login-via-sso redirect endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the login + servlet to be mounted. + + Args: + client_redirect_url: the client redirect URL to pass to the login redirect + endpoint + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets redirected to (ie, the SSO server) + """ + params = {} + if client_redirect_url: + params["redirectUrl"] = client_redirect_url + + # hit the redirect url (which should redirect back to the redirect url. This + # is the easiest way of figuring out what the Host header ought to be set to + # to keep Synapse happy. + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), + ) + assert channel.code == 302 + + # hit the redirect url again with the right Host header, which should now issue + # a cookie and redirect to the SSO provider. + location = channel.headers.getRawHeaders("Location")[0] + parts = urllib.parse.urlsplit(location) + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + urllib.parse.urlunsplit(("", "") + parts[2:]), + custom_headers=[ + ("Host", parts[1]), + ], + ) + + assert channel.code == 302 + channel.extract_cookies(cookies) + return channel.headers.getRawHeaders("Location")[0] + + def initiate_sso_ui_auth( + self, ui_auth_session_id: str, cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the ui-auth-via-sso endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the + AuthRestServlet to be mounted. + + Args: + ui_auth_session_id: the session id of the UI auth + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets linked to (ie, the SSO server) + """ + sso_redirect_endpoint = ( + "/_matrix/client/r0/auth/m.login.sso/fallback/web?" + + urllib.parse.urlencode({"session": ui_auth_session_id}) + ) + # hit the redirect url (which will issue a cookie and state) + channel = make_request( + self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint + ) + # that should serve a confirmation page + assert channel.code == 200, channel.text_body + channel.extract_cookies(cookies) + + # parse the confirmation page to fish out the link. + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + assert len(p.links) == 1, "not exactly one link in confirmation page" + oauth_uri = p.links[0] + return oauth_uri + + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" +TEST_OIDC_CONFIG = { + "enabled": True, + "discover": False, + "issuer": "https://issuer.test", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py deleted file mode 100644 index 5e83dba2ed..0000000000 --- a/tests/rest/client/v1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py deleted file mode 100644 index d2181ea907..0000000000 --- a/tests/rest/client/v1/test_directory.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.rest import admin -from synapse.rest.client import directory, login, room -from synapse.types import RoomAlias -from synapse.util.stringutils import random_string - -from tests import unittest -from tests.unittest import override_config - - -class DirectoryTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["require_membership_for_aliases"] = True - - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, homeserver): - self.room_owner = self.register_user("room_owner", "test") - self.room_owner_tok = self.login("room_owner", "test") - - self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok - ) - - self.user = self.register_user("user", "test") - self.user_tok = self.login("user", "test") - - def test_state_event_not_in_room(self): - self.ensure_user_left_room() - self.set_alias_via_state_event(403) - - def test_directory_endpoint_not_in_room(self): - self.ensure_user_left_room() - self.set_alias_via_directory(403) - - def test_state_event_in_room_too_long(self): - self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) - - def test_directory_in_room_too_long(self): - self.ensure_user_joined_room() - self.set_alias_via_directory(400, alias_length=256) - - @override_config({"default_room_version": 5}) - def test_state_event_user_in_v5_room(self): - """Test that a regular user can add alias events before room v6""" - self.ensure_user_joined_room() - self.set_alias_via_state_event(200) - - @override_config({"default_room_version": 6}) - def test_state_event_v6_room(self): - """Test that a regular user can *not* add alias events from room v6""" - self.ensure_user_joined_room() - self.set_alias_via_state_event(403) - - def test_directory_in_room(self): - self.ensure_user_joined_room() - self.set_alias_via_directory(200) - - def test_room_creation_too_long(self): - url = "/_matrix/client/r0/createRoom" - - # We use deliberately a localpart under the length threshold so - # that we can make sure that the check is done on the whole alias. - data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} - request_data = json.dumps(data) - channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_room_creation(self): - url = "/_matrix/client/r0/createRoom" - - # Check with an alias of allowed length. There should already be - # a test that ensures it works in test_register.py, but let's be - # as cautious as possible here. - data = {"room_alias_name": random_string(5)} - request_data = json.dumps(data) - channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, 200, channel.result) - - def set_alias_via_state_event(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( - self.room_id, - self.hs.hostname, - ) - - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def set_alias_via_directory(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) - data = {"room_id": self.room_id} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def random_alias(self, length): - return RoomAlias(random_string(length), self.hs.hostname).to_string() - - def ensure_user_left_room(self): - self.ensure_membership("leave") - - def ensure_user_joined_room(self): - self.ensure_membership("join") - - def ensure_membership(self, membership): - try: - if membership == "leave": - self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) - if membership == "join": - self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) - except AssertionError: - # We don't care whether the leave request didn't return a 200 (e.g. - # if the user isn't already in the room), because we only want to - # make sure the user isn't in the room. - pass diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py deleted file mode 100644 index a90294003e..0000000000 --- a/tests/rest/client/v1/test_events.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" Tests REST events for /events paths.""" - -from unittest.mock import Mock - -import synapse.rest.admin -from synapse.rest.client import events, login, room - -from tests import unittest - - -class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): - """Tests event streaming (GET /events).""" - - servlets = [ - events.register_servlets, - room.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - config["enable_registration_captcha"] = False - config["enable_registration"] = True - config["auto_join_rooms"] = [] - - hs = self.setup_test_homeserver(config=config) - - hs.get_federation_handler = Mock() - - return hs - - def prepare(self, reactor, clock, hs): - - # register an account - self.user_id = self.register_user("sid1", "pass") - self.token = self.login(self.user_id, "pass") - - # register a 2nd account - self.other_user = self.register_user("other2", "pass") - self.other_token = self.login(self.other_user, "pass") - - def test_stream_basic_permissions(self): - # invalid token, expect 401 - # note: this is in violation of the original v1 spec, which expected - # 403. However, since the v1 spec no longer exists and the v1 - # implementation is now part of the r0 implementation, the newer - # behaviour is used instead to be consistent with the r0 spec. - # see issue #2602 - channel = self.make_request( - "GET", "/events?access_token=%s" % ("invalid" + self.token,) - ) - self.assertEquals(channel.code, 401, msg=channel.result) - - # valid token, expect content - channel = self.make_request( - "GET", "/events?access_token=%s&timeout=0" % (self.token,) - ) - self.assertEquals(channel.code, 200, msg=channel.result) - self.assertTrue("chunk" in channel.json_body) - self.assertTrue("start" in channel.json_body) - self.assertTrue("end" in channel.json_body) - - def test_stream_room_permissions(self): - room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) - self.helper.send(room_id, tok=self.other_token) - - # invited to room (expect no content for room) - self.helper.invite( - room_id, src=self.other_user, targ=self.user_id, tok=self.other_token - ) - - # valid token, expect content - channel = self.make_request( - "GET", "/events?access_token=%s&timeout=0" % (self.token,) - ) - self.assertEquals(channel.code, 200, msg=channel.result) - - # We may get a presence event for ourselves down - self.assertEquals( - 0, - len( - [ - c - for c in channel.json_body["chunk"] - if not ( - c.get("type") == "m.presence" - and c["content"].get("user_id") == self.user_id - ) - ] - ), - ) - - # joined room (expect all content for room) - self.helper.join(room=room_id, user=self.user_id, tok=self.token) - - # left to room (expect no content for room) - - def TODO_test_stream_items(self): - # new user, no content - - # join room, expect 1 item (join) - - # send message, expect 2 items (join,send) - - # set topic, expect 3 items (join,send,topic) - - # someone else join room, expect 4 (join,send,topic,join) - - # someone else send message, expect 5 (join,send.topic,join,send) - - # someone else set topic, expect 6 (join,send,topic,join,send,topic) - pass - - -class GetEventsTestCase(unittest.HomeserverTestCase): - servlets = [ - events.register_servlets, - room.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def prepare(self, hs, reactor, clock): - - # register an account - self.user_id = self.register_user("sid1", "pass") - self.token = self.login(self.user_id, "pass") - - self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) - - def test_get_event_via_events(self): - resp = self.helper.send(self.room_id, tok=self.token) - event_id = resp["event_id"] - - channel = self.make_request( - "GET", - "/events/" + event_id, - access_token=self.token, - ) - self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py deleted file mode 100644 index eba3552b19..0000000000 --- a/tests/rest/client/v1/test_login.py +++ /dev/null @@ -1,1345 +0,0 @@ -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import urllib.parse -from typing import Any, Dict, List, Optional, Union -from unittest.mock import Mock -from urllib.parse import urlencode - -import pymacaroons - -from twisted.web.resource import Resource - -import synapse.rest.admin -from synapse.appservice import ApplicationService -from synapse.rest.client import devices, login, logout, register -from synapse.rest.client.account import WhoamiRestServlet -from synapse.rest.synapse.client import build_synapse_client_resource_tree -from synapse.types import create_requester - -from tests import unittest -from tests.handlers.test_oidc import HAS_OIDC -from tests.handlers.test_saml import has_saml2 -from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG -from tests.test_utils.html_parsers import TestHtmlParser -from tests.unittest import HomeserverTestCase, override_config, skip_unless - -try: - import jwt - - HAS_JWT = True -except ImportError: - HAS_JWT = False - - -# synapse server name: used to populate public_baseurl in some tests -SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" - -# public_baseurl for some tests. It uses an http:// scheme because -# FakeChannel.isSecure() returns False, so synapse will see the requested uri as -# http://..., so using http in the public_baseurl stops Synapse trying to redirect to -# https://.... -BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) - -# CAS server used in some tests -CAS_SERVER = "https://fake.test" - -# just enough to tell pysaml2 where to redirect to -SAML_SERVER = "https://test.saml.server/idp/sso" -TEST_SAML_METADATA = """ - - - - - -""" % { - "SAML_SERVER": SAML_SERVER, -} - -LOGIN_URL = b"/_matrix/client/r0/login" -TEST_URL = b"/_matrix/client/r0/account/whoami" - -# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + -TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' - -# the query params in TEST_CLIENT_REDIRECT_URL -EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] - -# (possibly experimental) login flows we expect to appear in the list after the normal -# ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] - - -class LoginRestServletTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - logout.register_servlets, - devices.register_servlets, - lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.enable_registration = True - self.hs.config.registrations_require_3pid = [] - self.hs.config.auto_join_rooms = [] - self.hs.config.enable_registration_captcha = False - - return self.hs - - @override_config( - { - "rc_login": { - "address": {"per_second": 0.17, "burst_count": 5}, - # Prevent the account login ratelimiter from raising first - # - # This is normally covered by the default test homeserver config - # which sets these values to 10000, but as we're overriding the entire - # rc_login dict here, we need to set this manually as well - "account": {"per_second": 10000, "burst_count": 10000}, - } - } - ) - def test_POST_ratelimiting_per_address(self): - # Create different users so we're sure not to be bothered by the per-user - # ratelimiter. - for i in range(0, 6): - self.register_user("kermit" + str(i), "monkey") - - for i in range(0, 6): - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower - # than 1min. - self.assertTrue(retry_after_ms < 6000) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config( - { - "rc_login": { - "account": {"per_second": 0.17, "burst_count": 5}, - # Prevent the address login ratelimiter from raising first - # - # This is normally covered by the default test homeserver config - # which sets these values to 10000, but as we're overriding the entire - # rc_login dict here, we need to set this manually as well - "address": {"per_second": 10000, "burst_count": 10000}, - } - } - ) - def test_POST_ratelimiting_per_account(self): - self.register_user("kermit", "monkey") - - for i in range(0, 6): - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower - # than 1min. - self.assertTrue(retry_after_ms < 6000) - - self.reactor.advance(retry_after_ms / 1000.0) - - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config( - { - "rc_login": { - # Prevent the address login ratelimiter from raising first - # - # This is normally covered by the default test homeserver config - # which sets these values to 10000, but as we're overriding the entire - # rc_login dict here, we need to set this manually as well - "address": {"per_second": 10000, "burst_count": 10000}, - "failed_attempts": {"per_second": 0.17, "burst_count": 5}, - } - } - ) - def test_POST_ratelimiting_per_account_failed_attempts(self): - self.register_user("kermit", "monkey") - - for i in range(0, 6): - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "notamonkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"403", channel.result) - - # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower - # than 1min. - self.assertTrue(retry_after_ms < 6000) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "notamonkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): - self.register_user("kermit", "monkey") - - # we shouldn't be able to make requests without an access token - channel = self.make_request(b"GET", TEST_URL) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") - - # log in as normal - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.code, 200, channel.result) - access_token = channel.json_body["access_token"] - device_id = channel.json_body["device_id"] - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # - # test behaviour after deleting the expired device - # - - # we now log in as a different device - access_token_2 = self.login("kermit", "monkey") - - # more requests with the expired token should still return a soft-logout - self.reactor.advance(3600) - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # ... but if we delete that device, it will be a proper logout - self._delete_device(access_token_2, "kermit", "monkey", device_id) - - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], False) - - def _delete_device(self, access_token, user_id, password, device_id): - """Perform the UI-Auth to delete a device""" - channel = self.make_request( - b"DELETE", "devices/" + device_id, access_token=access_token - ) - self.assertEquals(channel.code, 401, channel.result) - # check it's a UI-Auth fail - self.assertEqual( - set(channel.json_body.keys()), - {"flows", "params", "session"}, - channel.result, - ) - - auth = { - "type": "m.login.password", - # https://github.com/matrix-org/synapse/issues/5665 - # "identifier": {"type": "m.id.user", "user": user_id}, - "user": user_id, - "password": password, - "session": channel.json_body["session"], - } - - channel = self.make_request( - b"DELETE", - "devices/" + device_id, - access_token=access_token, - content={"auth": auth}, - ) - self.assertEquals(channel.code, 200, channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): - self.register_user("kermit", "monkey") - - # log in as normal - access_token = self.login("kermit", "monkey") - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # Now try to hard logout this session - channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): - self.register_user("kermit", "monkey") - - # log in as normal - access_token = self.login("kermit", "monkey") - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # Now try to hard log out all of the user's sessions - channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) - - -@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") -class MultiSSOTestCase(unittest.HomeserverTestCase): - """Tests for homeservers with multiple SSO providers enabled""" - - servlets = [ - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - - config["public_baseurl"] = BASE_URL - - config["cas_config"] = { - "enabled": True, - "server_url": CAS_SERVER, - "service_url": "https://matrix.goodserver.com:8448", - } - - config["saml2_config"] = { - "sp_config": { - "metadata": {"inline": [TEST_SAML_METADATA]}, - # use the XMLSecurity backend to avoid relying on xmlsec1 - "crypto_backend": "XMLSecurity", - }, - } - - # default OIDC provider - config["oidc_config"] = TEST_OIDC_CONFIG - - # additional OIDC providers - config["oidc_providers"] = [ - { - "idp_id": "idp1", - "idp_name": "IDP1", - "discover": False, - "issuer": "https://issuer1", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": "https://issuer1/auth", - "token_endpoint": "https://issuer1/token", - "userinfo_endpoint": "https://issuer1/userinfo", - "user_mapping_provider": { - "config": {"localpart_template": "{{ user.sub }}"} - }, - } - ] - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d.update(build_synapse_client_resource_tree(self.hs)) - return d - - def test_get_login_flows(self): - """GET /login should return password and SSO flows""" - channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) - - expected_flow_types = [ - "m.login.cas", - "m.login.sso", - "m.login.token", - "m.login.password", - ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] - - self.assertCountEqual( - [f["type"] for f in channel.json_body["flows"]], expected_flow_types - ) - - @override_config({"experimental_features": {"msc2858_enabled": True}}) - def test_get_msc2858_login_flows(self): - """The SSO flow should include IdP info if MSC2858 is enabled""" - channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) - - # stick the flows results in a dict by type - flow_results: Dict[str, Any] = {} - for f in channel.json_body["flows"]: - flow_type = f["type"] - self.assertNotIn( - flow_type, flow_results, "duplicate flow type %s" % (flow_type,) - ) - flow_results[flow_type] = f - - self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") - sso_flow = flow_results.pop("m.login.sso") - # we should have a set of IdPs - self.assertCountEqual( - sso_flow["org.matrix.msc2858.identity_providers"], - [ - {"id": "cas", "name": "CAS"}, - {"id": "saml", "name": "SAML"}, - {"id": "oidc-idp1", "name": "IDP1"}, - {"id": "oidc", "name": "OIDC"}, - ], - ) - - # the rest of the flows are simple - expected_flows = [ - {"type": "m.login.cas"}, - {"type": "m.login.token"}, - {"type": "m.login.password"}, - ] + ADDITIONAL_LOGIN_FLOWS - - self.assertCountEqual(flow_results.values(), expected_flows) - - def test_multi_sso_redirect(self): - """/login/sso/redirect should redirect to an identity picker""" - # first hit the redirect url, which should redirect to our idp picker - channel = self._make_sso_redirect_request(False, None) - self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] - - # hitting that picker should give us some HTML - channel = self.make_request("GET", uri) - self.assertEqual(channel.code, 200, channel.result) - - # parse the form to check it has fields assumed elsewhere in this class - html = channel.result["body"].decode("utf-8") - p = TestHtmlParser() - p.feed(html) - p.close() - - # there should be a link for each href - returned_idps: List[str] = [] - for link in p.links: - path, query = link.split("?", 1) - self.assertEqual(path, "pick_idp") - params = urllib.parse.parse_qs(query) - self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) - returned_idps.append(params["idp"][0]) - - self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - - def test_multi_sso_redirect_to_cas(self): - """If CAS is chosen, should redirect to the CAS server""" - - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=cas", - shorthand=False, - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - cas_uri = location_headers[0] - cas_uri_path, cas_uri_query = cas_uri.split("?", 1) - - # it should redirect us to the login page of the cas server - self.assertEqual(cas_uri_path, CAS_SERVER + "/login") - - # check that the redirectUrl is correctly encoded in the service param - ie, the - # place that CAS will redirect to - cas_uri_params = urllib.parse.parse_qs(cas_uri_query) - service_uri = cas_uri_params["service"][0] - _, service_uri_query = service_uri.split("?", 1) - service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - - def test_multi_sso_redirect_to_saml(self): - """If SAML is chosen, should redirect to the SAML server""" - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=saml", - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - saml_uri = location_headers[0] - saml_uri_path, saml_uri_query = saml_uri.split("?", 1) - - # it should redirect us to the login page of the SAML server - self.assertEqual(saml_uri_path, SAML_SERVER) - - # the RelayState is used to carry the client redirect url - saml_uri_params = urllib.parse.parse_qs(saml_uri_query) - relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - - def test_login_via_oidc(self): - """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - oidc_uri = location_headers[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - # ... and should have set a cookie including the redirect url - cookie_headers = channel.headers.getRawHeaders("Set-Cookie") - assert cookie_headers - cookies: Dict[str, str] = {} - for h in cookie_headers: - key, value = h.split(";")[0].split("=", maxsplit=1) - cookies[key] = value - - oidc_session_cookie = cookies["oidc_session"] - macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) - self.assertEqual( - self._get_value_from_macaroon(macaroon, "client_redirect_url"), - TEST_CLIENT_REDIRECT_URL, - ) - - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) - - # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) - content_type_headers = channel.headers.getRawHeaders("Content-Type") - assert content_type_headers - self.assertTrue(content_type_headers[-1].startswith("text/html")) - p = TestHtmlParser() - p.feed(channel.text_body) - p.close() - - # ... which should contain our redirect link - self.assertEqual(len(p.links), 1) - path, query = p.links[0].split("?", 1) - self.assertEqual(path, "https://x") - - # it will have url-encoded the params properly, so we'll have to parse them - params = urllib.parse.parse_qsl( - query, keep_blank_values=True, strict_parsing=True, errors="strict" - ) - self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) - self.assertEqual(params[2][0], "loginToken") - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - login_token = params[2][1] - chan = self.make_request( - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@user1:test") - - def test_multi_sso_redirect_to_unknown(self): - """An unknown IdP should cause a 400""" - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_client_idp_redirect_to_unknown(self): - """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request(False, "xxx") - self.assertEqual(channel.code, 404, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - - def test_client_idp_redirect_to_oidc(self): - """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request(False, "oidc") - self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - @override_config({"experimental_features": {"msc2858_enabled": True}}) - def test_client_msc2858_redirect_to_oidc(self): - """Test the unstable API""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - def test_client_idp_redirect_msc2858_disabled(self): - """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - def _make_sso_redirect_request( - self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None - ): - """Send a request to /_matrix/client/r0/login/sso/redirect - - ... or the unstable equivalent - - ... possibly specifying an IDP provider - """ - endpoint = ( - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" - if unstable_endpoint - else "/_matrix/client/r0/login/sso/redirect" - ) - if idp_prov is not None: - endpoint += "/" + idp_prov - endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - - return self.make_request( - "GET", - endpoint, - custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], - ) - - @staticmethod - def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: - prefix = key + " = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(prefix): - return caveat.caveat_id[len(prefix) :] - raise ValueError("No %s caveat in macaroon" % (key,)) - - -class CASTestCase(unittest.HomeserverTestCase): - - servlets = [ - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.base_url = "https://matrix.goodserver.com/" - self.redirect_path = "_synapse/client/login/sso/redirect/confirm" - - config = self.default_config() - config["public_baseurl"] = ( - config.get("public_baseurl") or "https://matrix.goodserver.com:8448" - ) - config["cas_config"] = { - "enabled": True, - "server_url": CAS_SERVER, - } - - cas_user_id = "username" - self.user_id = "@%s:test" % cas_user_id - - async def get_raw(uri, args): - """Return an example response payload from a call to the `/proxyValidate` - endpoint of a CAS server, copied from - https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 - - This needs to be returned by an async function (as opposed to set as the - mock's return value) because the corresponding Synapse code awaits on it. - """ - return ( - """ - - - %s - PGTIOU-84678-8a9d... - - https://proxy2/pgtUrl - https://proxy1/pgtUrl - - - - """ - % cas_user_id - ).encode("utf-8") - - mocked_http_client = Mock(spec=["get_raw"]) - mocked_http_client.get_raw.side_effect = get_raw - - self.hs = self.setup_test_homeserver( - config=config, - proxied_http_client=mocked_http_client, - ) - - return self.hs - - def prepare(self, reactor, clock, hs): - self.deactivate_account_handler = hs.get_deactivate_account_handler() - - def test_cas_redirect_confirm(self): - """Tests that the SSO login flow serves a confirmation page before redirecting a - user to the redirect URL. - """ - base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" - redirect_url = "https://dodgy-site.com/" - - url_parts = list(urllib.parse.urlparse(base_url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({"redirectUrl": redirect_url}) - query.update({"ticket": "ticket"}) - url_parts[4] = urllib.parse.urlencode(query) - cas_ticket_url = urllib.parse.urlunparse(url_parts) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - # Test that the response is HTML. - self.assertEqual(channel.code, 200, channel.result) - content_type_header_value = "" - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type_header_value = header[1].decode("utf8") - - self.assertTrue(content_type_header_value.startswith("text/html")) - - # Test that the body isn't empty. - self.assertTrue(len(channel.result["body"]) > 0) - - # And that it contains our redirect link - self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) - - @override_config( - { - "sso": { - "client_whitelist": [ - "https://legit-site.com/", - "https://other-site.com/", - ] - } - } - ) - def test_cas_redirect_whitelisted(self): - """Tests that the SSO login flow serves a redirect to a whitelisted url""" - self._test_redirect("https://legit-site.com/") - - @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): - self._test_redirect("https://example.com/_matrix/static/client/login") - - def _test_redirect(self, redirect_url): - """Tests that the SSO login flow serves a redirect for the given redirect URL.""" - cas_ticket_url = ( - "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" - % (urllib.parse.quote(redirect_url)) - ) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - self.assertEqual(channel.code, 302) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) - - @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): - """Logging in as a deactivated account should error.""" - redirect_url = "https://legit-site.com/" - - # First login (to create the user). - self._test_redirect(redirect_url) - - # Deactivate the account. - self.get_success( - self.deactivate_account_handler.deactivate_account( - self.user_id, False, create_requester(self.user_id) - ) - ) - - # Request the CAS ticket. - cas_ticket_url = ( - "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" - % (urllib.parse.quote(redirect_url)) - ) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - # Because the user is deactivated they are served an error template. - self.assertEqual(channel.code, 403) - self.assertIn(b"SSO account deactivated", channel.result["body"]) - - -@skip_unless(HAS_JWT, "requires jwt") -class JWTTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - jwt_secret = "secret" - jwt_algorithm = "HS256" - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_secret - self.hs.config.jwt_algorithm = self.jwt_algorithm - return self.hs - - def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) - if isinstance(result, bytes): - return result.decode("ascii") - return result - - def jwt_login(self, *args): - params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) - return channel - - def test_login_jwt_valid_registered(self): - self.register_user("kermit", "monkey") - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - def test_login_jwt_valid_unregistered(self): - channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@frog:test") - - def test_login_jwt_invalid_signature(self): - channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: Signature verification failed", - ) - - def test_login_jwt_expired(self): - channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Signature has expired" - ) - - def test_login_jwt_not_before(self): - now = int(time.time()) - channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: The token is not yet valid (nbf)", - ) - - def test_login_no_sub(self): - channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual(channel.json_body["error"], "Invalid JWT") - - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "issuer": "test-issuer", - } - } - ) - def test_login_iss(self): - """Test validating the issuer claim.""" - # A valid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - # An invalid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid issuer" - ) - - # Not providing an issuer. - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - 'JWT validation failed: Token is missing the "iss" claim', - ) - - def test_login_iss_no_config(self): - """Test providing an issuer claim without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "audiences": ["test-audience"], - } - } - ) - def test_login_aud(self): - """Test validating the audience claim.""" - # A valid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - # An invalid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" - ) - - # Not providing an audience. - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - 'JWT validation failed: Token is missing the "aud" claim', - ) - - def test_login_aud_no_config(self): - """Test providing an audience without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" - ) - - def test_login_no_token(self): - params = {"type": "org.matrix.login.jwt"} - channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") - - -# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use -# RSS256, with a public key configured in synapse as "jwt_secret", and tokens -# signed by the private key. -@skip_unless(HAS_JWT, "requires jwt") -class JWTPubKeyTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - ] - - # This key's pubkey is used as the jwt_secret setting of synapse. Valid - # tokens are signed by this and validated using the pubkey. It is generated - # with `openssl genrsa 512` (not a secure way to generate real keys, but - # good enough for tests!) - jwt_privatekey = "\n".join( - [ - "-----BEGIN RSA PRIVATE KEY-----", - "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", - "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", - "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", - "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", - "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", - "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", - "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", - "-----END RSA PRIVATE KEY-----", - ] - ) - - # Generated with `openssl rsa -in foo.key -pubout`, with the the above - # private key placed in foo.key (jwt_privatekey). - jwt_pubkey = "\n".join( - [ - "-----BEGIN PUBLIC KEY-----", - "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", - "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", - "-----END PUBLIC KEY-----", - ] - ) - - # This key is used to sign tokens that shouldn't be accepted by synapse. - # Generated just like jwt_privatekey. - bad_privatekey = "\n".join( - [ - "-----BEGIN RSA PRIVATE KEY-----", - "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", - "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", - "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", - "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", - "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", - "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", - "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", - "-----END RSA PRIVATE KEY-----", - ] - ) - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_pubkey - self.hs.config.jwt_algorithm = "RS256" - return self.hs - - def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") - if isinstance(result, bytes): - return result.decode("ascii") - return result - - def jwt_login(self, *args): - params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) - return channel - - def test_login_jwt_valid(self): - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - def test_login_jwt_invalid_signature(self): - channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: Signature verification failed", - ) - - -AS_USER = "as_user_alice" - - -class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - register.register_servlets, - ] - - def register_as_user(self, username): - self.make_request( - b"POST", - "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), - {"username": username}, - ) - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - - self.service = ApplicationService( - id="unique_identifier", - token="some_token", - hostname="example.com", - sender="@asbot:example.com", - namespaces={ - ApplicationService.NS_USERS: [ - {"regex": r"@as_user.*", "exclusive": False} - ], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, - ) - self.another_service = ApplicationService( - id="another__identifier", - token="another_token", - hostname="example.com", - sender="@as2bot:example.com", - namespaces={ - ApplicationService.NS_USERS: [ - {"regex": r"@as2_user.*", "exclusive": False} - ], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, - ) - - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) - return self.hs - - def test_login_appservice_user(self): - """Test that an appservice user can use /login""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_login_appservice_user_bot(self): - """Test that the appservice bot can use /login""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": self.service.sender}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_login_appservice_wrong_user(self): - """Test that non-as users cannot login with the as token""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - def test_login_appservice_wrong_as(self): - """Test that as users cannot login with wrong as token""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.another_service.token - ) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - def test_login_appservice_no_token(self): - """Test that users must provide a token when using the appservice - login method - """ - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"401", channel.result) - - -@skip_unless(HAS_OIDC, "requires OIDC") -class UsernamePickerTestCase(HomeserverTestCase): - """Tests for the username picker flow of SSO login""" - - servlets = [login.register_servlets] - - def default_config(self): - config = super().default_config() - config["public_baseurl"] = BASE_URL - - config["oidc_config"] = {} - config["oidc_config"].update(TEST_OIDC_CONFIG) - config["oidc_config"]["user_mapping_provider"] = { - "config": {"display_name_template": "{{ user.displayname }}"} - } - - # whitelist this client URI so we redirect straight to it rather than - # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://x"]} - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d.update(build_synapse_client_resource_tree(self.hs)) - return d - - def test_username_picker(self): - """Test the happy path of a username picker flow.""" - - # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL - ) - - # that should redirect to the username picker - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - picker_url = location_headers[0] - self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") - - # ... with a username_mapping_session cookie - cookies: Dict[str, str] = {} - channel.extract_cookies(cookies) - self.assertIn("username_mapping_session", cookies) - session_id = cookies["username_mapping_session"] - - # introspect the sso handler a bit to check that the username mapping session - # looks ok. - username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions - self.assertIn( - session_id, - username_mapping_sessions, - "session id not found in map", - ) - session = username_mapping_sessions[session_id] - self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) - - # the expiry time should be about 15 minutes away - expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) - self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) - - # Now, submit a username to the username picker, which should serve a redirect - # to the completion page - content = urlencode({b"username": b"bobby"}).encode("utf8") - chan = self.make_request( - "POST", - path=picker_url, - content=content, - content_is_form=True, - custom_headers=[ - ("Cookie", "username_mapping_session=" + session_id), - # old versions of twisted don't do form-parsing without a valid - # content-length header. - ("Content-Length", str(len(content))), - ], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # send a request to the completion page, which should 302 to the client redirectUrl - chan = self.make_request( - "GET", - path=location_headers[0], - custom_headers=[("Cookie", "username_mapping_session=" + session_id)], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # ensure that the returned location matches the requested redirect URL - path, query = location_headers[0].split("?", 1) - self.assertEqual(path, "https://x") - - # it will have url-encoded the params properly, so we'll have to parse them - params = urllib.parse.parse_qsl( - query, keep_blank_values=True, strict_parsing=True, errors="strict" - ) - self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) - self.assertEqual(params[2][0], "loginToken") - - # fish the login token out of the returned redirect uri - login_token = params[2][1] - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - chan = self.make_request( - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py deleted file mode 100644 index 1d152352d1..0000000000 --- a/tests/rest/client/v1/test_presence.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import Mock - -from twisted.internet import defer - -from synapse.handlers.presence import PresenceHandler -from synapse.rest.client import presence -from synapse.types import UserID - -from tests import unittest - - -class PresenceTestCase(unittest.HomeserverTestCase): - """Tests presence REST API.""" - - user_id = "@sid:red" - - user = UserID.from_string(user_id) - servlets = [presence.register_servlets] - - def make_homeserver(self, reactor, clock): - - presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = defer.succeed(None) - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - presence_handler=presence_handler, - ) - - return hs - - def test_put_presence(self): - """ - PUT to the status endpoint with use_presence enabled will call - set_state on the presence handler. - """ - self.hs.config.use_presence = True - - body = {"presence": "here", "status_msg": "beep boop"} - channel = self.make_request( - "PUT", "/presence/%s/status" % (self.user_id,), body - ) - - self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) - - @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): - """ - PUT to the status endpoint with use_presence disabled will NOT call - set_state on the presence handler. - """ - - body = {"presence": "here", "status_msg": "beep boop"} - channel = self.make_request( - "PUT", "/presence/%s/status" % (self.user_id,), body - ) - - self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py deleted file mode 100644 index 2860579c2e..0000000000 --- a/tests/rest/client/v1/test_profile.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /profile paths.""" -from synapse.rest import admin -from synapse.rest.client import login, profile, room - -from tests import unittest - - -class ProfileTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - return self.hs - - def prepare(self, reactor, clock, hs): - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") - self.other = self.register_user("other", "pass", displayname="Bob") - - def test_get_displayname(self): - res = self._get_displayname() - self.assertEqual(res, "owner") - - def test_set_displayname(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res = self._get_displayname() - self.assertEqual(res, "test") - - def test_set_displayname_noauth(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test"}, - ) - self.assertEqual(channel.code, 401, channel.result) - - def test_set_displayname_too_long(self): - """Attempts to set a stupid displayname should get a 400""" - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test" * 100}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - res = self._get_displayname() - self.assertEqual(res, "owner") - - def test_get_displayname_other(self): - res = self._get_displayname(self.other) - self.assertEquals(res, "Bob") - - def test_set_displayname_other(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.other,), - content={"displayname": "test"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_get_avatar_url(self): - res = self._get_avatar_url() - self.assertIsNone(res) - - def test_set_avatar_url(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res = self._get_avatar_url() - self.assertEqual(res, "http://my.server/pic.gif") - - def test_set_avatar_url_noauth(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif"}, - ) - self.assertEqual(channel.code, 401, channel.result) - - def test_set_avatar_url_too_long(self): - """Attempts to set a stupid avatar_url should get a 400""" - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif" * 100}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - res = self._get_avatar_url() - self.assertIsNone(res) - - def test_get_avatar_url_other(self): - res = self._get_avatar_url(self.other) - self.assertIsNone(res) - - def test_set_avatar_url_other(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.other,), - content={"avatar_url": "http://my.server/pic.gif"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - def _get_displayname(self, name=None): - channel = self.make_request( - "GET", "/profile/%s/displayname" % (name or self.owner,) - ) - self.assertEqual(channel.code, 200, channel.result) - return channel.json_body["displayname"] - - def _get_avatar_url(self, name=None): - channel = self.make_request( - "GET", "/profile/%s/avatar_url" % (name or self.owner,) - ) - self.assertEqual(channel.code, 200, channel.result) - return channel.json_body.get("avatar_url") - - -class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - config["require_auth_for_profile_requests"] = True - config["limit_profile_requests_to_users_who_share_rooms"] = True - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, hs): - # User owning the requested profile. - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") - self.profile_url = "/profile/%s" % (self.owner) - - # User requesting the profile. - self.requester = self.register_user("requester", "pass") - self.requester_tok = self.login("requester", "pass") - - self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) - - def test_no_auth(self): - self.try_fetch_profile(401) - - def test_not_in_shared_room(self): - self.ensure_requester_left_room() - - self.try_fetch_profile(403, access_token=self.requester_tok) - - def test_in_shared_room(self): - self.ensure_requester_left_room() - - self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) - - self.try_fetch_profile(200, self.requester_tok) - - def try_fetch_profile(self, expected_code, access_token=None): - self.request_profile(expected_code, access_token=access_token) - - self.request_profile( - expected_code, url_suffix="/displayname", access_token=access_token - ) - - self.request_profile( - expected_code, url_suffix="/avatar_url", access_token=access_token - ) - - def request_profile(self, expected_code, url_suffix="", access_token=None): - channel = self.make_request( - "GET", self.profile_url + url_suffix, access_token=access_token - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def ensure_requester_left_room(self): - try: - self.helper.leave( - room=self.room_id, user=self.requester, tok=self.requester_tok - ) - except AssertionError: - # We don't care whether the leave request didn't return a 200 (e.g. - # if the user isn't already in the room), because we only want to - # make sure the user isn't in the room. - pass - - -class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["require_auth_for_profile_requests"] = True - config["limit_profile_requests_to_users_who_share_rooms"] = True - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, hs): - # User requesting the profile. - self.requester = self.register_user("requester", "pass") - self.requester_tok = self.login("requester", "pass") - - def test_can_lookup_own_profile(self): - """Tests that a user can lookup their own profile without having to be in a room - if 'require_auth_for_profile_requests' is set to true in the server's config. - """ - channel = self.make_request( - "GET", "/profile/" + self.requester, access_token=self.requester_tok - ) - self.assertEqual(channel.code, 200, channel.result) - - channel = self.make_request( - "GET", - "/profile/" + self.requester + "/displayname", - access_token=self.requester_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - channel = self.make_request( - "GET", - "/profile/" + self.requester + "/avatar_url", - access_token=self.requester_tok, - ) - self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py deleted file mode 100644 index d0ce91ccd9..0000000000 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import synapse -from synapse.api.errors import Codes -from synapse.rest.client import login, push_rule, room - -from tests.unittest import HomeserverTestCase - - -class PushRuleAttributesTestCase(HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - push_rule.register_servlets, - ] - hijack_auth = False - - def test_enabled_on_creation(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is enabled upon creation, even though a rule with that - ruleId existed previously and was disabled. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_on_recreation(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is enabled upon creation, even if a rule with that - ruleId existed previously and was disabled. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # disable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": False}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule disabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], False) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_disable(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is disabled and enabled when we ask for it. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # disable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": False}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule disabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], False) - - # re-enable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule enabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_404_when_get_non_existent(self): - """ - Tests that `enabled` gives 404 when the rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # check 404 for deleted rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_get_non_existent_server_rule(self): - """ - Tests that `enabled` gives 404 when the server-default rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_put_non_existent_rule(self): - """ - Tests that `enabled` gives 404 when we put to a rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_put_non_existent_server_rule(self): - """ - Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/.m.muahahah/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_get(self): - """ - Tests that `actions` gives you what you expect on a fresh rule. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET actions for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/actions", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] - ) - - def test_actions_put(self): - """ - Tests that PUT on actions updates the value you'd get from GET. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # change the rule actions - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # GET actions for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/actions", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["actions"], ["dont_notify"]) - - def test_actions_404_when_get_non_existent(self): - """ - Tests that `actions` gives 404 when the rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # check 404 for deleted rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_get_non_existent_server_rule(self): - """ - Tests that `actions` gives 404 when the server-default rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_put_non_existent_rule(self): - """ - Tests that `actions` gives 404 when putting to a rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_put_non_existent_server_rule(self): - """ - Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/.m.muahahah/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py deleted file mode 100644 index 0c9cbb9aff..0000000000 --- a/tests/rest/client/v1/test_rooms.py +++ /dev/null @@ -1,2150 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /rooms paths.""" - -import json -from typing import Iterable -from unittest.mock import Mock, call -from urllib import parse as urlparse - -from twisted.internet import defer - -import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, Membership -from synapse.api.errors import HttpResponseException -from synapse.handlers.pagination import PurgeStatus -from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, room -from synapse.types import JsonDict, RoomAlias, UserID, create_requester -from synapse.util.stringutils import random_string - -from tests import unittest -from tests.test_utils import make_awaitable - -PATH_PREFIX = b"/_matrix/client/api/v1" - - -class RoomBase(unittest.HomeserverTestCase): - rmcreator_id = None - - servlets = [room.register_servlets, room.register_deprecated_servlets] - - def make_homeserver(self, reactor, clock): - - self.hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - - self.hs.get_federation_handler = Mock() - self.hs.get_federation_handler.return_value.maybe_backfill = Mock( - return_value=make_awaitable(None) - ) - - async def _insert_client_ip(*args, **kwargs): - return None - - self.hs.get_datastore().insert_client_ip = _insert_client_ip - - return self.hs - - -class RoomPermissionsTestCase(RoomBase): - """Tests room permissions.""" - - user_id = "@sid1:red" - rmcreator_id = "@notme:red" - - def prepare(self, reactor, clock, hs): - - self.helper.auth_user_id = self.rmcreator_id - # create some rooms under the name rmcreator_id - self.uncreated_rmid = "!aa:test" - self.created_rmid = self.helper.create_room_as( - self.rmcreator_id, is_public=False - ) - self.created_public_rmid = self.helper.create_room_as( - self.rmcreator_id, is_public=True - ) - - # send a message in one of the rooms - self.created_rmid_msg_path = ( - "rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ).encode("ascii") - channel = self.make_request( - "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' - ) - self.assertEquals(200, channel.code, channel.result) - - # set topic for public room - channel = self.make_request( - "PUT", - ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), - b'{"topic":"Public Room Topic"}', - ) - self.assertEquals(200, channel.code, channel.result) - - # auth as user_id now - self.helper.auth_user_id = self.user_id - - def test_can_do_action(self): - msg_content = b'{"msgtype":"m.text","body":"hello"}' - - seq = iter(range(100)) - - def send_msg_path(): - return "/rooms/%s/send/m.room.message/mid%s" % ( - self.created_rmid, - str(next(seq)), - ) - - # send message in uncreated room, expect 403 - channel = self.make_request( - "PUT", - "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), - msg_content, - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room not joined (no state), expect 403 - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room and invited, expect 403 - self.helper.invite( - room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id - ) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room and joined, expect 200 - self.helper.join(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # send message in created room and left, expect 403 - self.helper.leave(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_topic_perms(self): - topic_content = b'{"topic":"My Topic Name"}' - topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid - - # set/get topic in uncreated room, expect 403 - channel = self.make_request( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request( - "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set/get topic in created PRIVATE room not joined, expect 403 - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set topic in created PRIVATE room and invited, expect 403 - self.helper.invite( - room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id - ) - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # get topic in created PRIVATE room and invited, expect 403 - channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set/get topic in created PRIVATE room and joined, expect 200 - self.helper.join(room=self.created_rmid, user=self.user_id) - - # Only room ops can set topic by default - self.helper.auth_user_id = self.rmcreator_id - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.helper.auth_user_id = self.user_id - - channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) - - # set/get topic in created PRIVATE room and left, expect 403 - self.helper.leave(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # get topic in PUBLIC room, not joined, expect 403 - channel = self.make_request( - "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set topic in PUBLIC room, not joined, expect 403 - channel = self.make_request( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content, - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def _test_get_membership( - self, room=None, members: Iterable = frozenset(), expect_code=None - ): - for member in members: - path = "/rooms/%s/state/m.room.member/%s" % (room, member) - channel = self.make_request("GET", path) - self.assertEquals(expect_code, channel.code) - - def test_membership_basic_room_perms(self): - # === room does not exist === - room = self.uncreated_rmid - # get membership of self, get membership of other, uncreated room - # expect all 403s - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # trying to invite people to this room should 403 - self.helper.invite( - room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 - ) - - # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 404s because room doesn't exist on any server - for usr in [self.user_id, self.rmcreator_id]: - self.helper.join(room=room, user=usr, expect_code=404) - self.helper.leave(room=room, user=usr, expect_code=404) - - def test_membership_private_room_perms(self): - room = self.created_rmid - # get membership of self, get membership of other, private room + invite - # expect all 403s - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # get membership of self, get membership of other, private room + joined - # expect all 200s - self.helper.join(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - # get membership of self, get membership of other, private room + left - # expect all 200s - self.helper.leave(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - def test_membership_public_room_perms(self): - room = self.created_public_rmid - # get membership of self, get membership of other, public room + invite - # expect 403 - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # get membership of self, get membership of other, public room + joined - # expect all 200s - self.helper.join(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - # get membership of self, get membership of other, public room + left - # expect 200. - self.helper.leave(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - def test_invited_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - - # set [invite/join/left] of other user, expect 403s - self.helper.invite( - room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 - ) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.JOIN, - expect_code=403, - ) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403, - ) - - def test_joined_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self.helper.join(room=room, user=self.user_id) - - # set invited of self, expect 403 - self.helper.invite( - room=room, src=self.user_id, targ=self.user_id, expect_code=403 - ) - - # set joined of self, expect 200 (NOOP) - self.helper.join(room=room, user=self.user_id) - - other = "@burgundy:red" - # set invited of other, expect 200 - self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) - - # set joined of other, expect 403 - self.helper.change_membership( - room=room, - src=self.user_id, - targ=other, - membership=Membership.JOIN, - expect_code=403, - ) - - # set left of other, expect 403 - self.helper.change_membership( - room=room, - src=self.user_id, - targ=other, - membership=Membership.LEAVE, - expect_code=403, - ) - - # set left of self, expect 200 - self.helper.leave(room=room, user=self.user_id) - - def test_leave_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self.helper.join(room=room, user=self.user_id) - self.helper.leave(room=room, user=self.user_id) - - # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 403s - for usr in [self.user_id, self.rmcreator_id]: - self.helper.change_membership( - room=room, - src=self.user_id, - targ=usr, - membership=Membership.INVITE, - expect_code=403, - ) - - self.helper.change_membership( - room=room, - src=self.user_id, - targ=usr, - membership=Membership.JOIN, - expect_code=403, - ) - - # It is always valid to LEAVE if you've already left (currently.) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403, - ) - - -class RoomsMemberListTestCase(RoomBase): - """Tests /rooms/$room_id/members/list REST events.""" - - user_id = "@sid1:red" - - def test_get_member_list(self): - room_id = self.helper.create_room_as(self.user_id) - channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - def test_get_member_list_no_room(self): - channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_get_member_list_no_permission(self): - room_id = self.helper.create_room_as("@some_other_guy:red") - channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_get_member_list_mixed_memberships(self): - room_creator = "@some_other_guy:red" - room_id = self.helper.create_room_as(room_creator) - room_path = "/rooms/%s/members" % room_id - self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) - # can't see list if you're just invited. - channel = self.make_request("GET", room_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - self.helper.join(room=room_id, user=self.user_id) - # can see list now joined - channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - self.helper.leave(room=room_id, user=self.user_id) - # can see old list once left - channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - -class RoomsCreateTestCase(RoomBase): - """Tests /rooms and /rooms/$room_id REST events.""" - - user_id = "@sid1:red" - - def test_post_room_no_keys(self): - # POST with no config keys, expect new room id - channel = self.make_request("POST", "/createRoom", "{}") - - self.assertEquals(200, channel.code, channel.result) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_visibility_key(self): - # POST with visibility config key, expect new room id - channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_custom_key(self): - # POST with custom config keys, expect new room id - channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_known_and_unknown_keys(self): - # POST with custom + known config keys, expect new room id - channel = self.make_request( - "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' - ) - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_invalid_content(self): - # POST with invalid content / paths, expect 400 - channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEquals(400, channel.code) - - channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEquals(400, channel.code) - - def test_post_room_invitees_invalid_mxid(self): - # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 - # Note the trailing space in the MXID here! - channel = self.make_request( - "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' - ) - self.assertEquals(400, channel.code) - - @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) - def test_post_room_invitees_ratelimit(self): - """Test that invites sent when creating a room are ratelimited by a RateLimiter, - which ratelimits them correctly, including by not limiting when the requester is - exempt from ratelimiting. - """ - - # Build the request's content. We use local MXIDs because invites over federation - # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") - - # Test that the invites are correctly ratelimited. - channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) - self.assertEqual( - "Cannot invite so many users at once", - channel.json_body["error"], - ) - - # Add the current user to the ratelimit overrides, allowing them no ratelimiting. - self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) - ) - - # Test that the invites aren't ratelimited anymore. - channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) - - -class RoomTopicTestCase(RoomBase): - """Tests /rooms/$room_id/topic REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - # create the room - self.room_id = self.helper.create_room_as(self.user_id) - self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - - def test_invalid_puts(self): - # missing keys or invalid json - channel = self.make_request("PUT", self.path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request( - "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' - ) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # valid key, wrong type - content = '{"topic":["Topic name"]}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_topic(self): - # nothing should be there - channel = self.make_request("GET", self.path) - self.assertEquals(404, channel.code, msg=channel.result["body"]) - - # valid put - content = '{"topic":"Topic name"}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # valid get - channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(content), channel.json_body) - - def test_rooms_topic_with_extra_keys(self): - # valid put with extra keys - content = '{"topic":"Seasons","subtopic":"Summer"}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # valid get - channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(content), channel.json_body) - - -class RoomMemberStateTestCase(RoomBase): - """Tests /rooms/$room_id/members/$user_id/state REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_invalid_puts(self): - path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) - # missing keys or invalid json - channel = self.make_request("PUT", path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # valid keys, wrong types - content = '{"membership":["%s","%s","%s"]}' % ( - Membership.INVITE, - Membership.JOIN, - Membership.LEAVE, - ) - channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_members_self(self): - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.user_id, - ) - - # valid join message (NOOP since we made the room) - content = '{"membership":"%s"}' % Membership.JOIN - channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - expected_response = {"membership": Membership.JOIN} - self.assertEquals(expected_response, channel.json_body) - - def test_rooms_members_other(self): - self.other_id = "@zzsid1:red" - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.other_id, - ) - - # valid invite message - content = '{"membership":"%s"}' % Membership.INVITE - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) - - def test_rooms_members_other_custom_keys(self): - self.other_id = "@zzsid1:red" - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.other_id, - ) - - # valid invite message with custom key - content = '{"membership":"%s","invite_text":"%s"}' % ( - Membership.INVITE, - "Join us!", - ) - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) - - -class RoomInviteRatelimitTestCase(RoomBase): - user_id = "@sid1:red" - - servlets = [ - admin.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - @unittest.override_config( - {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_invites_by_rooms_ratelimit(self): - """Tests that invites in a room are actually rate-limited.""" - room_id = self.helper.create_room_as(self.user_id) - - for i in range(3): - self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) - - self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) - - @unittest.override_config( - {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_invites_by_users_ratelimit(self): - """Tests that invites to a specific user are actually rate-limited.""" - - for _ in range(3): - room_id = self.helper.create_room_as(self.user_id) - self.helper.invite(room_id, self.user_id, "@other-users:red") - - room_id = self.helper.create_room_as(self.user_id) - self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) - - -class RoomJoinRatelimitTestCase(RoomBase): - user_id = "@sid1:red" - - servlets = [ - admin.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit(self): - """Tests that local joins are actually rate-limited.""" - for _ in range(3): - self.helper.create_room_as(self.user_id) - - self.helper.create_room_as(self.user_id, expect_code=429) - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit_profile_change(self): - """Tests that sending a profile update into all of the user's joined rooms isn't - rate-limited by the rate-limiter on joins.""" - - # Create and join as many rooms as the rate-limiting config allows in a second. - room_ids = [ - self.helper.create_room_as(self.user_id), - self.helper.create_room_as(self.user_id), - self.helper.create_room_as(self.user_id), - ] - # Let some time for the rate-limiter to forget about our multi-join. - self.reactor.advance(2) - # Add one to make sure we're joined to more rooms than the config allows us to - # join in a second. - room_ids.append(self.helper.create_room_as(self.user_id)) - - # Create a profile for the user, since it hasn't been done on registration. - store = self.hs.get_datastore() - self.get_success( - store.create_profile(UserID.from_string(self.user_id).localpart) - ) - - # Update the display name for the user. - path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id - channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEquals(channel.code, 200, channel.json_body) - - # Check that all the rooms have been sent a profile update into. - for room_id in room_ids: - path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( - room_id, - self.user_id, - ) - - channel = self.make_request("GET", path) - self.assertEquals(channel.code, 200) - - self.assertIn("displayname", channel.json_body) - self.assertEquals(channel.json_body["displayname"], "John Doe") - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit_idempotent(self): - """Tests that the room join endpoints remain idempotent despite rate-limiting - on room joins.""" - room_id = self.helper.create_room_as(self.user_id) - - # Let's test both paths to be sure. - paths_to_test = [ - "/_matrix/client/r0/rooms/%s/join", - "/_matrix/client/r0/join/%s", - ] - - for path in paths_to_test: - # Make sure we send more requests than the rate-limiting config would allow - # if all of these requests ended up joining the user to a room. - for _ in range(4): - channel = self.make_request("POST", path % room_id, {}) - self.assertEquals(channel.code, 200) - - @unittest.override_config( - { - "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, - "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], - "autocreate_auto_join_rooms": True, - }, - ) - def test_autojoin_rooms(self): - user_id = self.register_user("testuser", "password") - - # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) - self.assertEqual(len(rooms), 4) - - -class RoomMessagesTestCase(RoomBase): - """Tests /rooms/$room_id/messages/$user_id/$msg_id REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_invalid_puts(self): - path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - # missing keys or invalid json - channel = self.make_request("PUT", path, b"{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b"text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b"") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_messages_sent(self): - path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - - content = b'{"body":"test","msgtype":{"type":"a"}}' - channel = self.make_request("PUT", path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # custom message types - content = b'{"body":"test","msgtype":"test.custom.text"}' - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # m.text message type - path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) - content = b'{"body":"test2","msgtype":"m.text"}' - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - -class RoomInitialSyncTestCase(RoomBase): - """Tests /rooms/$room_id/initialSync.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - # create the room - self.room_id = self.helper.create_room_as(self.user_id) - - def test_initial_sync(self): - channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEquals(200, channel.code) - - self.assertEquals(self.room_id, channel.json_body["room_id"]) - self.assertEquals("join", channel.json_body["membership"]) - - # Room state is easier to assert on if we unpack it into a dict - state = {} - for event in channel.json_body["state"]: - if "state_key" not in event: - continue - t = event["type"] - if t not in state: - state[t] = [] - state[t].append(event) - - self.assertTrue("m.room.create" in state) - - self.assertTrue("messages" in channel.json_body) - self.assertTrue("chunk" in channel.json_body["messages"]) - self.assertTrue("end" in channel.json_body["messages"]) - - self.assertTrue("presence" in channel.json_body) - - presence_by_user = { - e["content"]["user_id"]: e for e in channel.json_body["presence"] - } - self.assertTrue(self.user_id in presence_by_user) - self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) - - -class RoomMessageListTestCase(RoomBase): - """Tests /rooms/$room_id/messages REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_topo_token_is_accepted(self): - token = "t1-0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) - ) - self.assertEquals(200, channel.code) - self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) - self.assertTrue("chunk" in channel.json_body) - self.assertTrue("end" in channel.json_body) - - def test_stream_token_is_accepted_for_fwd_pagianation(self): - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) - ) - self.assertEquals(200, channel.code) - self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body["start"]) - self.assertTrue("chunk" in channel.json_body) - self.assertTrue("end" in channel.json_body) - - def test_room_messages_purge(self): - store = self.hs.get_datastore() - pagination_handler = self.hs.get_pagination_handler() - - # Send a first message in the room, which will be removed by the purge. - first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] - first_token = self.get_success( - store.get_topological_token_for_event(first_event_id) - ) - first_token_str = self.get_success(first_token.to_string(store)) - - # Send a second message in the room, which won't be removed, and which we'll - # use as the marker to purge events before. - second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] - second_token = self.get_success( - store.get_topological_token_for_event(second_event_id) - ) - second_token_str = self.get_success(second_token.to_string(store)) - - # Send a third event in the room to ensure we don't fall under any edge case - # due to our marker being the latest forward extremity in the room. - self.helper.send(self.room_id, "message 3") - - # Check that we get the first and second message when querying /messages. - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - second_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) - - # Purge every event before the second event. - purge_id = random_string(16) - pagination_handler._purges_by_id[purge_id] = PurgeStatus() - self.get_success( - pagination_handler._purge_history( - purge_id=purge_id, - room_id=self.room_id, - token=second_token_str, - delete_local_events=True, - ) - ) - - # Check that we only get the second message through /message now that the first - # has been purged. - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - second_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) - - # Check that we get no event, but also no error, when querying /messages with - # the token that was pointing at the first event, because we don't have it - # anymore. - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - first_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) - - -class RoomSearchTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - user_id = True - hijack_auth = False - - def prepare(self, reactor, clock, hs): - - # Register the user who does the searching - self.user_id = self.register_user("user", "pass") - self.access_token = self.login("user", "pass") - - # Register the user who sends the message - self.other_user_id = self.register_user("otheruser", "pass") - self.other_access_token = self.login("otheruser", "pass") - - # Create a room - self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) - - # Invite the other person - self.helper.invite( - room=self.room, - src=self.user_id, - tok=self.access_token, - targ=self.other_user_id, - ) - - # The other user joins - self.helper.join( - room=self.room, user=self.other_user_id, tok=self.other_access_token - ) - - def test_finds_message(self): - """ - The search functionality will search for content in messages if asked to - do so. - """ - # The other user sends some messages - self.helper.send(self.room, body="Hi!", tok=self.other_access_token) - self.helper.send(self.room, body="There!", tok=self.other_access_token) - - channel = self.make_request( - "POST", - "/search?access_token=%s" % (self.access_token,), - { - "search_categories": { - "room_events": {"keys": ["content.body"], "search_term": "Hi"} - } - }, - ) - - # Check we get the results we expect -- one search result, of the sent - # messages - self.assertEqual(channel.code, 200) - results = channel.json_body["search_categories"]["room_events"] - self.assertEqual(results["count"], 1) - self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") - - # No context was requested, so we should get none. - self.assertEqual(results["results"][0]["context"], {}) - - def test_include_context(self): - """ - When event_context includes include_profile, profile information will be - included in the search response. - """ - # The other user sends some messages - self.helper.send(self.room, body="Hi!", tok=self.other_access_token) - self.helper.send(self.room, body="There!", tok=self.other_access_token) - - channel = self.make_request( - "POST", - "/search?access_token=%s" % (self.access_token,), - { - "search_categories": { - "room_events": { - "keys": ["content.body"], - "search_term": "Hi", - "event_context": {"include_profile": True}, - } - } - }, - ) - - # Check we get the results we expect -- one search result, of the sent - # messages - self.assertEqual(channel.code, 200) - results = channel.json_body["search_categories"]["room_events"] - self.assertEqual(results["count"], 1) - self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") - - # We should get context info, like the two users, and the display names. - context = results["results"][0]["context"] - self.assertEqual(len(context["profile_info"].keys()), 2) - self.assertEqual( - context["profile_info"][self.other_user_id]["displayname"], "otheruser" - ) - - -class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - self.url = b"/_matrix/client/r0/publicRooms" - - config = self.default_config() - config["allow_public_rooms_without_auth"] = False - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def test_restricted_no_auth(self): - channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) - - def test_restricted_auth(self): - self.register_user("user", "pass") - tok = self.login("user", "pass") - - channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) - - -class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): - """Test that we correctly fallback to local filtering if a remote server - doesn't support search. - """ - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - return self.setup_test_homeserver(federation_client=Mock()) - - def prepare(self, reactor, clock, hs): - self.register_user("user", "pass") - self.token = self.login("user", "pass") - - self.federation_client = hs.get_federation_client() - - def test_simple(self): - "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = ( - lambda *a, **k: defer.succeed({}) - ) - - search_filter = {"generic_search_term": "foobar"} - - channel = self.make_request( - "POST", - b"/_matrix/client/r0/publicRooms?server=testserv", - content={"filter": search_filter}, - access_token=self.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.federation_client.get_public_rooms.assert_called_once_with( - "testserv", - limit=100, - since_token=None, - search_filter=search_filter, - include_all_networks=False, - third_party_instance_id=None, - ) - - def test_fallback(self): - "Test that searching public rooms over federation falls back if it gets a 404" - - # The `get_public_rooms` should be called again if the first call fails - # with a 404, when using search filters. - self.federation_client.get_public_rooms.side_effect = ( - HttpResponseException(404, "Not Found", b""), - defer.succeed({}), - ) - - search_filter = {"generic_search_term": "foobar"} - - channel = self.make_request( - "POST", - b"/_matrix/client/r0/publicRooms?server=testserv", - content={"filter": search_filter}, - access_token=self.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.federation_client.get_public_rooms.assert_has_calls( - [ - call( - "testserv", - limit=100, - since_token=None, - search_filter=search_filter, - include_all_networks=False, - third_party_instance_id=None, - ), - call( - "testserv", - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, - ), - ] - ) - - -class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["allow_per_room_profiles"] = False - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - - # Set a profile for the test user - self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) - channel = self.make_request( - "PUT", - "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), - request_data, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - def test_per_room_profile_forbidden(self): - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) - channel = self.make_request( - "PUT", - "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" - % (self.room_id, self.user_id), - request_data, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res_displayname = channel.json_body["content"]["displayname"] - self.assertEqual(res_displayname, self.displayname, channel.result) - - -class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): - """Tests that clients can add a "reason" field to membership events and - that they get correctly added to the generated events and propagated. - """ - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.creator = self.register_user("creator", "test") - self.creator_tok = self.login("creator", "test") - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - - self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) - - def test_join_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/join", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_leave_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/leave", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_kick_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/kick", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_ban_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/ban", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_unban_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/unban", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_invite_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/invite", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_reject_invite_reason(self): - self.helper.invite( - self.room_id, - src=self.creator, - targ=self.second_user_id, - tok=self.creator_tok, - ) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/leave", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def _check_for_reason(self, reason): - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( - self.room_id, self.second_user_id - ), - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - event_content = channel.json_body - - self.assertEqual(event_content.get("reason"), reason, channel.result) - - -class LabelsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - profile.register_servlets, - ] - - # Filter that should only catch messages with the label "#fun". - FILTER_LABELS = { - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - } - # Filter that should only catch messages without the label "#fun". - FILTER_NOT_LABELS = { - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - } - # Filter that should only catch messages with the label "#work" but without the label - # "#notfun". - FILTER_LABELS_NOT_LABELS = { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - def test_context_filter_labels(self): - """Test that we can filter by a label on a /context request.""" - event_id = self._send_labelled_messages_in_room() - - channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" - % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual( - len(events_before), 1, [event["content"] for event in events_before] - ) - self.assertEqual( - events_before[0]["content"]["body"], "with right label", events_before[0] - ) - - events_after = channel.json_body["events_before"] - - self.assertEqual( - len(events_after), 1, [event["content"] for event in events_after] - ) - self.assertEqual( - events_after[0]["content"]["body"], "with right label", events_after[0] - ) - - def test_context_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /context request.""" - event_id = self._send_labelled_messages_in_room() - - channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" - % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual( - len(events_before), 1, [event["content"] for event in events_before] - ) - self.assertEqual( - events_before[0]["content"]["body"], "without label", events_before[0] - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual( - len(events_after), 2, [event["content"] for event in events_after] - ) - self.assertEqual( - events_after[0]["content"]["body"], "with wrong label", events_after[0] - ) - self.assertEqual( - events_after[1]["content"]["body"], "with two wrong labels", events_after[1] - ) - - def test_context_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /context request. - """ - event_id = self._send_labelled_messages_in_room() - - channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" - % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual( - len(events_before), 0, [event["content"] for event in events_before] - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual( - len(events_after), 1, [event["content"] for event in events_after] - ) - self.assertEqual( - events_after[0]["content"]["body"], "with wrong label", events_after[0] - ) - - def test_messages_filter_labels(self): - """Test that we can filter by a label on a /messages request.""" - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_messages_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /messages request.""" - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 4, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "without label", events[1]) - self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) - self.assertEqual( - events[3]["content"]["body"], "with two wrong labels", events[3] - ) - - def test_messages_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /messages request. - """ - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % ( - self.room_id, - self.tok, - token, - json.dumps(self.FILTER_LABELS_NOT_LABELS), - ), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def test_search_filter_labels(self): - """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 2, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "with right label", - results[0]["result"]["content"]["body"], - ) - self.assertEqual( - results[1]["result"]["content"]["body"], - "with right label", - results[1]["result"]["content"]["body"], - ) - - def test_search_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 4, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "without label", - results[0]["result"]["content"]["body"], - ) - self.assertEqual( - results[1]["result"]["content"]["body"], - "without label", - results[1]["result"]["content"]["body"], - ) - self.assertEqual( - results[2]["result"]["content"]["body"], - "with wrong label", - results[2]["result"]["content"]["body"], - ) - self.assertEqual( - results[3]["result"]["content"]["body"], - "with two wrong labels", - results[3]["result"]["content"]["body"], - ) - - def test_search_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /search request. - """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 1, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "with wrong label", - results[0]["result"]["content"]["body"], - ) - - def _send_labelled_messages_in_room(self): - """Sends several messages to a room with different labels (or without any) to test - filtering by label. - Returns: - The ID of the event to use if we're testing filtering on /context. - """ - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=self.tok, - ) - - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=self.tok, - ) - # Return this event's ID when we test filtering in /context requests. - event_id = res["event_id"] - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=self.tok, - ) - - return event_id - - -class ContextTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - account.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("user", "password") - self.tok = self.login("user", "password") - self.room_id = self.helper.create_room_as( - self.user_id, tok=self.tok, is_public=False - ) - - self.other_user_id = self.register_user("user2", "password") - self.other_tok = self.login("user2", "password") - - self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) - self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) - - def test_erased_sender(self): - """Test that an erasure request results in the requester's events being hidden - from any new member of the room. - """ - - # Send a bunch of events in the room. - - self.helper.send(self.room_id, "message 1", tok=self.tok) - self.helper.send(self.room_id, "message 2", tok=self.tok) - event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] - self.helper.send(self.room_id, "message 4", tok=self.tok) - self.helper.send(self.room_id, "message 5", tok=self.tok) - - # Check that we can still see the messages before the erasure request. - - channel = self.make_request( - "GET", - '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' - % (self.room_id, event_id), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual(len(events_before), 2, events_before) - self.assertEqual( - events_before[0].get("content", {}).get("body"), - "message 2", - events_before[0], - ) - self.assertEqual( - events_before[1].get("content", {}).get("body"), - "message 1", - events_before[1], - ) - - self.assertEqual( - channel.json_body["event"].get("content", {}).get("body"), - "message 3", - channel.json_body["event"], - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual(len(events_after), 2, events_after) - self.assertEqual( - events_after[0].get("content", {}).get("body"), - "message 4", - events_after[0], - ) - self.assertEqual( - events_after[1].get("content", {}).get("body"), - "message 5", - events_after[1], - ) - - # Deactivate the first account and erase the user's data. - - deactivate_account_handler = self.hs.get_deactivate_account_handler() - self.get_success( - deactivate_account_handler.deactivate_account( - self.user_id, True, create_requester(self.user_id) - ) - ) - - # Invite another user in the room. This is needed because messages will be - # pruned only if the user wasn't a member of the room when the messages were - # sent. - - invited_user_id = self.register_user("user3", "password") - invited_tok = self.login("user3", "password") - - self.helper.invite( - self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok - ) - self.helper.join(self.room_id, invited_user_id, tok=invited_tok) - - # Check that a user that joined the room after the erasure request can't see - # the messages anymore. - - channel = self.make_request( - "GET", - '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' - % (self.room_id, event_id), - access_token=invited_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual(len(events_before), 2, events_before) - self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) - self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) - - self.assertDictEqual( - channel.json_body["event"].get("content"), {}, channel.json_body["event"] - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual(len(events_after), 2, events_after) - self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) - self.assertEqual(events_after[1].get("content"), {}, events_after[1]) - - -class RoomAliasListTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.room_owner = self.register_user("room_owner", "test") - self.room_owner_tok = self.login("room_owner", "test") - - self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok - ) - - def test_no_aliases(self): - res = self._get_aliases(self.room_owner_tok) - self.assertEqual(res["aliases"], []) - - def test_not_in_room(self): - self.register_user("user", "test") - user_tok = self.login("user", "test") - res = self._get_aliases(user_tok, expected_code=403) - self.assertEqual(res["errcode"], "M_FORBIDDEN") - - def test_admin_user(self): - alias1 = self._random_alias() - self._set_alias_via_directory(alias1) - - self.register_user("user", "test", admin=True) - user_tok = self.login("user", "test") - - res = self._get_aliases(user_tok) - self.assertEqual(res["aliases"], [alias1]) - - def test_with_aliases(self): - alias1 = self._random_alias() - alias2 = self._random_alias() - - self._set_alias_via_directory(alias1) - self._set_alias_via_directory(alias2) - - res = self._get_aliases(self.room_owner_tok) - self.assertEqual(set(res["aliases"]), {alias1, alias2}) - - def test_peekable_room(self): - alias1 = self._random_alias() - self._set_alias_via_directory(alias1) - - self.helper.send_state( - self.room_id, - EventTypes.RoomHistoryVisibility, - body={"history_visibility": "world_readable"}, - tok=self.room_owner_tok, - ) - - self.register_user("user", "test") - user_tok = self.login("user", "test") - - res = self._get_aliases(user_tok) - self.assertEqual(res["aliases"], [alias1]) - - def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: - """Calls the endpoint under test. returns the json response object.""" - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,), - access_token=access_token, - ) - self.assertEqual(channel.code, expected_code, channel.result) - res = channel.json_body - self.assertIsInstance(res, dict) - if expected_code == 200: - self.assertIsInstance(res["aliases"], list) - return res - - def _random_alias(self) -> str: - return RoomAlias(random_string(5), self.hs.hostname).to_string() - - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): - url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.room_owner_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - -class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.room_owner = self.register_user("room_owner", "test") - self.room_owner_tok = self.login("room_owner", "test") - - self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok - ) - - self.alias = "#alias:test" - self._set_alias_via_directory(self.alias) - - def _set_alias_via_directory(self, alias: str, expected_code: int = 200): - url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.room_owner_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: - """Calls the endpoint under test. returns the json response object.""" - channel = self.make_request( - "GET", - "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - access_token=self.room_owner_tok, - ) - self.assertEqual(channel.code, expected_code, channel.result) - res = channel.json_body - self.assertIsInstance(res, dict) - return res - - def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: - """Calls the endpoint under test. returns the json response object.""" - channel = self.make_request( - "PUT", - "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), - access_token=self.room_owner_tok, - ) - self.assertEqual(channel.code, expected_code, channel.result) - res = channel.json_body - self.assertIsInstance(res, dict) - return res - - def test_canonical_alias(self): - """Test a basic alias message.""" - # There is no canonical alias to start with. - self._get_canonical_alias(expected_code=404) - - # Create an alias. - self._set_canonical_alias({"alias": self.alias}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias}) - - # Now remove the alias. - self._set_canonical_alias({}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {}) - - def test_alt_aliases(self): - """Test a canonical alias message with alt_aliases.""" - # Create an alias. - self._set_canonical_alias({"alt_aliases": [self.alias]}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alt_aliases": [self.alias]}) - - # Now remove the alt_aliases. - self._set_canonical_alias({}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {}) - - def test_alias_alt_aliases(self): - """Test a canonical alias message with an alias and alt_aliases.""" - # Create an alias. - self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) - - # Now remove the alias and alt_aliases. - self._set_canonical_alias({}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {}) - - def test_partial_modify(self): - """Test removing only the alt_aliases.""" - # Create an alias. - self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) - - # Now remove the alt_aliases. - self._set_canonical_alias({"alias": self.alias}) - - # There is an alias event, but it is empty. - res = self._get_canonical_alias() - self.assertEqual(res, {"alias": self.alias}) - - def test_add_alias(self): - """Test removing only the alt_aliases.""" - # Create an additional alias. - second_alias = "#second:test" - self._set_alias_via_directory(second_alias) - - # Add the canonical alias. - self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) - - # Then add the second alias. - self._set_canonical_alias( - {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} - ) - - # Canonical alias now exists! - res = self._get_canonical_alias() - self.assertEqual( - res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} - ) - - def test_bad_data(self): - """Invalid data for alt_aliases should cause errors.""" - self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) - self._set_canonical_alias({"alt_aliases": None}, expected_code=400) - self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) - self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) - self._set_canonical_alias({"alt_aliases": False}, expected_code=400) - self._set_canonical_alias({"alt_aliases": True}, expected_code=400) - self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) - - def test_bad_alias(self): - """An alias which does not point to the room raises a SynapseError.""" - self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) - self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py deleted file mode 100644 index b54b004733..0000000000 --- a/tests/rest/client/v1/test_typing.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /rooms paths.""" - -from unittest.mock import Mock - -from synapse.rest.client import room -from synapse.types import UserID - -from tests import unittest - -PATH_PREFIX = "/_matrix/client/api/v1" - - -class RoomTypingTestCase(unittest.HomeserverTestCase): - """Tests /rooms/$room_id/typing/$user_id REST API.""" - - user_id = "@sid:red" - - user = UserID.from_string(user_id) - servlets = [room.register_servlets] - - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - - self.event_source = hs.get_event_sources().sources["typing"] - - hs.get_federation_handler = Mock() - - async def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - async def _insert_client_ip(*args, **kwargs): - return None - - hs.get_datastore().insert_client_ip = _insert_client_ip - - return hs - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - # Need another user to make notifications actually work - self.helper.join(self.room_id, user="@jim:red") - - def test_set_typing(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 1) - events = self.get_success( - self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) - ) - self.assertEquals( - events[0], - [ - { - "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.user_id]}, - } - ], - ) - - def test_set_not_typing(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": false}', - ) - self.assertEquals(200, channel.code) - - def test_typing_timeout(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 1) - - self.reactor.advance(36) - - self.assertEquals(self.event_source.get_current_key(), 2) - - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py deleted file mode 100644 index 954ad1a1fd..0000000000 --- a/tests/rest/client/v1/utils.py +++ /dev/null @@ -1,654 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import re -import time -import urllib.parse -from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union -from unittest.mock import patch - -import attr - -from twisted.web.resource import Resource -from twisted.web.server import Site - -from synapse.api.constants import Membership -from synapse.types import JsonDict - -from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse -from tests.test_utils.html_parsers import TestHtmlParser - - -@attr.s -class RestHelper: - """Contains extra helper functions to quickly and clearly perform a given - REST action, which isn't the focus of the test. - """ - - hs = attr.ib() - site = attr.ib(type=Site) - auth_user_id = attr.ib() - - def create_room_as( - self, - room_creator: Optional[str] = None, - is_public: bool = True, - room_version: Optional[str] = None, - tok: Optional[str] = None, - expect_code: int = 200, - extra_content: Optional[Dict] = None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ) -> str: - """ - Create a room. - - Args: - room_creator: The user ID to create the room with. - is_public: If True, the `visibility` parameter will be set to the - default (public). Otherwise, the `visibility` parameter will be set - to "private". - room_version: The room version to create the room as. Defaults to Synapse's - default room version. - tok: The access token to use in the request. - expect_code: The expected HTTP response code. - - Returns: - The ID of the newly created room. - """ - temp_id = self.auth_user_id - self.auth_user_id = room_creator - path = "/_matrix/client/r0/createRoom" - content = extra_content or {} - if not is_public: - content["visibility"] = "private" - if room_version: - content["room_version"] = room_version - if tok: - path = path + "?access_token=%s" % tok - - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - path, - json.dumps(content).encode("utf8"), - custom_headers=custom_headers, - ) - - assert channel.result["code"] == b"%d" % expect_code, channel.result - self.auth_user_id = temp_id - - if expect_code == 200: - return channel.json_body["room_id"] - - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): - self.change_membership( - room=room, - src=src, - targ=targ, - tok=tok, - membership=Membership.INVITE, - expect_code=expect_code, - ) - - def join(self, room=None, user=None, expect_code=200, tok=None): - self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.JOIN, - expect_code=expect_code, - ) - - def leave(self, room=None, user=None, expect_code=200, tok=None): - self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.LEAVE, - expect_code=expect_code, - ) - - def change_membership( - self, - room: str, - src: str, - targ: str, - membership: str, - extra_data: Optional[dict] = None, - tok: Optional[str] = None, - expect_code: int = 200, - ) -> None: - """ - Send a membership state event into a room. - - Args: - room: The ID of the room to send to - src: The mxid of the event sender - targ: The mxid of the event's target. The state key - membership: The type of membership event - extra_data: Extra information to include in the content of the event - tok: The user access token to use - expect_code: The expected HTTP response code - """ - temp_id = self.auth_user_id - self.auth_user_id = src - - path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) - if tok: - path = path + "?access_token=%s" % tok - - data = {"membership": membership} - data.update(extra_data or {}) - - channel = make_request( - self.hs.get_reactor(), - self.site, - "PUT", - path, - json.dumps(data).encode("utf8"), - ) - - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - self.auth_user_id = temp_id - - def send( - self, - room_id, - body=None, - txn_id=None, - tok=None, - expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ): - if body is None: - body = "body_text_here" - - content = {"msgtype": "m.text", "body": body} - - return self.send_event( - room_id, - "m.room.message", - content, - txn_id, - tok, - expect_code, - custom_headers=custom_headers, - ) - - def send_event( - self, - room_id, - type, - content: Optional[dict] = None, - txn_id=None, - tok=None, - expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - - path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) - if tok: - path = path + "?access_token=%s" % tok - - channel = make_request( - self.hs.get_reactor(), - self.site, - "PUT", - path, - json.dumps(content or {}).encode("utf8"), - custom_headers=custom_headers, - ) - - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - return channel.json_body - - def _read_write_state( - self, - room_id: str, - event_type: str, - body: Optional[Dict[str, Any]], - tok: str, - expect_code: int = 200, - state_key: str = "", - method: str = "GET", - ) -> Dict: - """Read or write some state from a given room - - Args: - room_id: - event_type: The type of state event - body: Body that is sent when making the request. The content of the state event. - If None, the request to the server will have an empty body - tok: The access token to use - expect_code: The HTTP code to expect in the response - state_key: - method: "GET" or "PUT" for reading or writing state, respectively - - Returns: - The response body from the server - - Raises: - AssertionError: if expect_code doesn't match the HTTP code we received - """ - path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( - room_id, - event_type, - state_key, - ) - if tok: - path = path + "?access_token=%s" % tok - - # Set request body if provided - content = b"" - if body is not None: - content = json.dumps(body).encode("utf8") - - channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - return channel.json_body - - def get_state( - self, - room_id: str, - event_type: str, - tok: str, - expect_code: int = 200, - state_key: str = "", - ): - """Gets some state from a room - - Args: - room_id: - event_type: The type of state event - tok: The access token to use - expect_code: The HTTP code to expect in the response - state_key: - - Returns: - The response body from the server - - Raises: - AssertionError: if expect_code doesn't match the HTTP code we received - """ - return self._read_write_state( - room_id, event_type, None, tok, expect_code, state_key, method="GET" - ) - - def send_state( - self, - room_id: str, - event_type: str, - body: Dict[str, Any], - tok: str, - expect_code: int = 200, - state_key: str = "", - ): - """Set some state in a room - - Args: - room_id: - event_type: The type of state event - body: Body that is sent when making the request. The content of the state event. - tok: The access token to use - expect_code: The HTTP code to expect in the response - state_key: - - Returns: - The response body from the server - - Raises: - AssertionError: if expect_code doesn't match the HTTP code we received - """ - return self._read_write_state( - room_id, event_type, body, tok, expect_code, state_key, method="PUT" - ) - - def upload_media( - self, - resource: Resource, - image_data: bytes, - tok: str, - filename: str = "test.png", - expect_code: int = 200, - ) -> dict: - """Upload a piece of test media to the media repo - Args: - resource: The resource that will handle the upload request - image_data: The image data to upload - tok: The user token to use during the upload - filename: The filename of the media to be uploaded - expect_code: The return code to expect from attempting to upload the media - """ - image_length = len(image_data) - path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - channel = make_request( - self.hs.get_reactor(), - FakeSite(resource), - "POST", - path, - content=image_data, - access_token=tok, - custom_headers=[(b"Content-Length", str(image_length))], - ) - - assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( - expect_code, - int(channel.result["code"]), - channel.result["body"], - ) - - return channel.json_body - - def login_via_oidc(self, remote_user_id: str) -> JsonDict: - """Log in (as a new user) via OIDC - - Returns the result of the final token login. - - Requires that "oidc_config" in the homeserver config be set appropriately - (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a - "public_base_url". - - Also requires the login servlet and the OIDC callback resource to be mounted at - the normal places. - """ - client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) - - # expect a confirmation page - assert channel.code == 200, channel.result - - # fish the matrix login token out of the body of the confirmation page - m = re.search( - 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), - channel.text_body, - ) - assert m, channel.text_body - login_token = m.group(1) - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - assert channel.code == 200 - return channel.json_body - - def auth_via_oidc( - self, - user_info_dict: JsonDict, - client_redirect_url: Optional[str] = None, - ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: - """Perform an OIDC authentication flow via a mock OIDC provider. - - This can be used for either login or user-interactive auth. - - Starts by making a request to the relevant synapse redirect endpoint, which is - expected to serve a 302 to the OIDC provider. We then make a request to the - OIDC callback endpoint, intercepting the HTTP requests that will get sent back - to the OIDC provider. - - Requires that "oidc_config" in the homeserver config be set appropriately - (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a - "public_base_url". - - Also requires the login servlet and the OIDC callback resource to be mounted at - the normal places. - - Args: - user_info_dict: the remote userinfo that the OIDC provider should present. - Typically this should be '{"sub": ""}'. - client_redirect_url: for a login flow, the client redirect URL to pass to - the login redirect endpoint - ui_auth_session_id: if set, we will perform a UI Auth flow. The session id - of the UI auth. - - Returns: - A FakeChannel containing the result of calling the OIDC callback endpoint. - Note that the response code may be a 200, 302 or 400 depending on how things - went. - """ - - cookies = {} - - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) - - # we now have a URI for the OIDC IdP, but we skip that and go straight - # back to synapse's OIDC callback resource. However, we do need the "state" - # param that synapse passes to the IdP via query params, as well as the cookie - # that synapse passes to the client. - - oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( - "unexpected SSO URI " + oauth_uri_path - ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) - - def complete_oidc_auth( - self, - oauth_uri: str, - cookies: Mapping[str, str], - user_info_dict: JsonDict, - ) -> FakeChannel: - """Mock out an OIDC authentication flow - - Assumes that an OIDC auth has been initiated by one of initiate_sso_login or - initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to - Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get - sent back to the OIDC provider. - - Requires the OIDC callback resource to be mounted at the normal place. - - Args: - oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, - from initiate_sso_login or initiate_sso_ui_auth). - cookies: the cookies set by synapse's redirect endpoint, which will be - sent back to the callback endpoint. - user_info_dict: the remote userinfo that the OIDC provider should present. - Typically this should be '{"sub": ""}'. - - Returns: - A FakeChannel containing the result of calling the OIDC callback endpoint. - """ - _, oauth_uri_qs = oauth_uri.split("?", 1) - params = urllib.parse.parse_qs(oauth_uri_qs) - callback_uri = "%s?%s" % ( - urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), - ) - - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req(method: str, uri: str, data=None, headers=None): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=200, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): - # now hit the callback URI with the right params and a made-up code - channel = make_request( - self.hs.get_reactor(), - self.site, - "GET", - callback_uri, - custom_headers=[ - ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() - ], - ) - return channel - - def initiate_sso_login( - self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] - ) -> str: - """Make a request to the login-via-sso redirect endpoint, and return the target - - Assumes that exactly one SSO provider has been configured. Requires the login - servlet to be mounted. - - Args: - client_redirect_url: the client redirect URL to pass to the login redirect - endpoint - cookies: any cookies returned will be added to this dict - - Returns: - the URI that the client gets redirected to (ie, the SSO server) - """ - params = {} - if client_redirect_url: - params["redirectUrl"] = client_redirect_url - - # hit the redirect url (which should redirect back to the redirect url. This - # is the easiest way of figuring out what the Host header ought to be set to - # to keep Synapse happy. - channel = make_request( - self.hs.get_reactor(), - self.site, - "GET", - "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), - ) - assert channel.code == 302 - - # hit the redirect url again with the right Host header, which should now issue - # a cookie and redirect to the SSO provider. - location = channel.headers.getRawHeaders("Location")[0] - parts = urllib.parse.urlsplit(location) - channel = make_request( - self.hs.get_reactor(), - self.site, - "GET", - urllib.parse.urlunsplit(("", "") + parts[2:]), - custom_headers=[ - ("Host", parts[1]), - ], - ) - - assert channel.code == 302 - channel.extract_cookies(cookies) - return channel.headers.getRawHeaders("Location")[0] - - def initiate_sso_ui_auth( - self, ui_auth_session_id: str, cookies: MutableMapping[str, str] - ) -> str: - """Make a request to the ui-auth-via-sso endpoint, and return the target - - Assumes that exactly one SSO provider has been configured. Requires the - AuthRestServlet to be mounted. - - Args: - ui_auth_session_id: the session id of the UI auth - cookies: any cookies returned will be added to this dict - - Returns: - the URI that the client gets linked to (ie, the SSO server) - """ - sso_redirect_endpoint = ( - "/_matrix/client/r0/auth/m.login.sso/fallback/web?" - + urllib.parse.urlencode({"session": ui_auth_session_id}) - ) - # hit the redirect url (which will issue a cookie and state) - channel = make_request( - self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint - ) - # that should serve a confirmation page - assert channel.code == 200, channel.text_body - channel.extract_cookies(cookies) - - # parse the confirmation page to fish out the link. - p = TestHtmlParser() - p.feed(channel.text_body) - p.close() - assert len(p.links) == 1, "not exactly one link in confirmation page" - oauth_uri = p.links[0] - return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc. -TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py deleted file mode 100644 index b946fca8b3..0000000000 --- a/tests/rest/client/v2_alpha/test_account.py +++ /dev/null @@ -1,1000 +0,0 @@ -# Copyright 2015-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os -import re -from email.parser import Parser -from typing import Optional - -import pkg_resources - -import synapse.rest.admin -from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes, HttpResponseException -from synapse.appservice import ApplicationService -from synapse.rest.client import account, login, register, room -from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource - -from tests import unittest -from tests.server import FakeSite, make_request -from tests.unittest import override_config - - -class PasswordResetTestCase(unittest.HomeserverTestCase): - - servlets = [ - account.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Email config. - config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - hs = self.setup_test_homeserver(config=config) - - async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] - hs.get_send_email_handler()._sendmail = sendmail - - return hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.submit_token_resource = PasswordResetSubmitTokenResource(hs) - - def test_basic_password_reset(self): - """Test basic password reset flow""" - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = self._request_token(email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - self._reset_password(new_password, session_id, client_secret) - - # Assert we can log in with the new password - self.login("kermit", new_password) - - # Assert we can't log in with the old password - self.attempt_wrong_password_login("kermit", old_password) - - @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_email(self): - """Test that we ratelimit /requestToken for the same email.""" - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test1@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - def reset(ip): - client_secret = "foobar" - session_id = self._request_token(email, client_secret, ip) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - self._reset_password(new_password, session_id, client_secret) - - self.email_attempts.clear() - - # We expect to be able to make three requests before getting rate - # limited. - # - # We change IPs to ensure that we're not being ratelimited due to the - # same IP - reset("127.0.0.1") - reset("127.0.0.2") - reset("127.0.0.3") - - with self.assertRaises(HttpResponseException) as cm: - reset("127.0.0.4") - - self.assertEqual(cm.exception.code, 429) - - def test_basic_password_reset_canonicalise_email(self): - """Test basic password reset flow - Request password reset with different spelling - """ - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email_profile = "test@example.com" - email_passwort_reset = "TEST@EXAMPLE.COM" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email_profile, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = self._request_token(email_passwort_reset, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - self._reset_password(new_password, session_id, client_secret) - - # Assert we can log in with the new password - self.login("kermit", new_password) - - # Assert we can't log in with the old password - self.attempt_wrong_password_login("kermit", old_password) - - def test_cant_reset_password_without_clicking_link(self): - """Test that we do actually need to click the link in the email""" - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = self._request_token(email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to reset password without clicking the link - self._reset_password(new_password, session_id, client_secret, expected_code=401) - - # Assert we can log in with the old password - self.login("kermit", old_password) - - # Assert we can't log in with the new password - self.attempt_wrong_password_login("kermit", new_password) - - def test_no_valid_token(self): - """Test that we do actually need to request a token and can't just - make a session up. - """ - old_password = "monkey" - new_password = "kangeroo" - - user_id = self.register_user("kermit", old_password) - self.login("kermit", old_password) - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - client_secret = "foobar" - session_id = "weasle" - - # Attempt to reset password without even requesting an email - self._reset_password(new_password, session_id, client_secret, expected_code=401) - - # Assert we can log in with the old password - self.login("kermit", old_password) - - # Assert we can't log in with the new password - self.attempt_wrong_password_login("kermit", new_password) - - @unittest.override_config({"request_token_inhibit_3pid_errors": True}) - def test_password_reset_bad_email_inhibit_error(self): - """Test that triggering a password reset with an email address that isn't bound - to an account doesn't leak the lack of binding for that address if configured - that way. - """ - self.register_user("kermit", "monkey") - self.login("kermit", "monkey") - - email = "test@example.com" - - client_secret = "foobar" - session_id = self._request_token(email, client_secret) - - self.assertIsNotNone(session_id) - - def _request_token(self, email, client_secret, ip="127.0.0.1"): - channel = self.make_request( - "POST", - b"account/password/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - client_ip=ip, - ) - - if channel.code != 200: - raise HttpResponseException( - channel.code, - channel.result["reason"], - channel.result["body"], - ) - - return channel.json_body["sid"] - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - # Load the password reset confirmation page - channel = make_request( - self.reactor, - FakeSite(self.submit_token_resource), - "GET", - path, - shorthand=False, - ) - - self.assertEquals(200, channel.code, channel.result) - - # Now POST to the same endpoint, mimicking the same behaviour as clicking the - # password reset confirm button - - # Confirm the password reset - channel = make_request( - self.reactor, - FakeSite(self.submit_token_resource), - "POST", - path, - content=b"", - shorthand=False, - content_is_form=True, - ) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) - - def _reset_password( - self, new_password, session_id, client_secret, expected_code=200 - ): - channel = self.make_request( - "POST", - b"account/password", - { - "new_password": new_password, - "auth": { - "type": LoginType.EMAIL_IDENTITY, - "threepid_creds": { - "client_secret": client_secret, - "sid": session_id, - }, - }, - }, - ) - self.assertEquals(expected_code, channel.code, channel.result) - - -class DeactivateTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - account.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - return self.hs - - def test_deactivate_account(self): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - - self.deactivate(user_id, tok) - - store = self.hs.get_datastore() - - # Check that the user has been marked as deactivated. - self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) - - # Check that this access token has been invalidated. - channel = self.make_request("GET", "account/whoami", access_token=tok) - self.assertEqual(channel.code, 401) - - def test_pending_invites(self): - """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() - - inviter_id = self.register_user("inviter", "test") - inviter_tok = self.login("inviter", "test") - - invitee_id = self.register_user("invitee", "test") - invitee_tok = self.login("invitee", "test") - - # Make @inviter:test invite @invitee:test in a new room. - room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) - self.helper.invite( - room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok - ) - - # Make sure the invite is here. - pending_invites = self.get_success( - store.get_invited_rooms_for_local_user(invitee_id) - ) - self.assertEqual(len(pending_invites), 1, pending_invites) - self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) - - # Deactivate @invitee:test. - self.deactivate(invitee_id, invitee_tok) - - # Check that the invite isn't there anymore. - pending_invites = self.get_success( - store.get_invited_rooms_for_local_user(invitee_id) - ) - self.assertEqual(len(pending_invites), 0, pending_invites) - - # Check that the membership of @invitee:test in the room is now "leave". - memberships = self.get_success( - store.get_rooms_for_local_user_where_membership_is( - invitee_id, [Membership.LEAVE] - ) - ) - self.assertEqual(len(memberships), 1, memberships) - self.assertEqual(memberships[0].room_id, room_id, memberships) - - def deactivate(self, user_id, tok): - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - } - ) - channel = self.make_request( - "POST", "account/deactivate", request_data, access_token=tok - ) - self.assertEqual(channel.code, 200) - - -class WhoamiTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - account.register_servlets, - register.register_servlets, - ] - - def test_GET_whoami(self): - device_id = "wouldgohere" - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test", device_id=device_id) - - whoami = self.whoami(tok) - self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) - - def test_GET_whoami_appservices(self): - user_id = "@as:test" - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": user_id, "exclusive": True}]}, - sender=user_id, - ) - self.hs.get_datastore().services_cache.append(appservice) - - whoami = self.whoami(as_token) - self.assertEqual(whoami, {"user_id": user_id}) - self.assertFalse(hasattr(whoami, "device_id")) - - def whoami(self, tok): - channel = self.make_request("GET", "account/whoami", {}, access_token=tok) - self.assertEqual(channel.code, 200) - return channel.json_body - - -class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - - servlets = [ - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Email config. - config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - self.hs = self.setup_test_homeserver(config=config) - - async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] - self.hs.get_send_email_handler()._sendmail = sendmail - - return self.hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.user_id = self.register_user("kermit", "test") - self.user_id_tok = self.login("kermit", "test") - self.email = "test@example.com" - self.url_3pid = b"account/3pid" - - def test_add_valid_email(self): - self.get_success(self._add_email(self.email, self.email)) - - def test_add_valid_email_second_time(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - self.email, - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) - ) - - def test_add_valid_email_second_time_canonicalise(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - "TEST@EXAMPLE.COM", - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) - ) - - def test_add_email_no_at(self): - self.get_success( - self._request_token_invalid_email( - "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_two_at(self): - self.get_success( - self._request_token_invalid_email( - "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_bad_format(self): - self.get_success( - self._request_token_invalid_email( - "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_domain_to_lower(self): - self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) - - def test_add_email_domain_with_umlaut(self): - self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) - - def test_add_email_address_casefold(self): - self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) - - def test_address_trim(self): - self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) - - @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_ip(self): - """Tests that adding emails is ratelimited by IP""" - - # We expect to be able to set three emails before getting ratelimited. - self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) - self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) - self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) - - with self.assertRaises(HttpResponseException) as cm: - self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) - - self.assertEqual(cm.exception.code, 429) - - def test_add_email_if_disabled(self): - """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False - - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email(self): - """Test deleting an email from profile""" - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email_if_disabled(self): - """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email""" - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to add email without clicking the link - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_no_valid_token(self): - """Test that we do actually need to request a token and can't just - make a session up. - """ - client_secret = "foobar" - session_id = "weasle" - - # Attempt to add email without even requesting an email - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - @override_config({"next_link_domain_whitelist": None}) - def test_next_link(self): - """Tests a valid next_link parameter value with no whitelist (good case)""" - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/a/good/site", - expect_code=200, - ) - - @override_config({"next_link_domain_whitelist": None}) - def test_next_link_exotic_protocol(self): - """Tests using a esoteric protocol as a next_link parameter value. - Someone may be hosting a client on IPFS etc. - """ - self._request_token( - "something@example.com", - "some_secret", - next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, - ) - - @override_config({"next_link_domain_whitelist": None}) - def test_next_link_file_uri(self): - """Tests next_link parameters cannot be file URI""" - # Attempt to use a next_link value that points to the local disk - self._request_token( - "something@example.com", - "some_secret", - next_link="file:///host/path", - expect_code=400, - ) - - @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) - def test_next_link_domain_whitelist(self): - """Tests next_link parameters must fit the whitelist if provided""" - - # Ensure not providing a next_link parameter still works - self._request_token( - "something@example.com", - "some_secret", - next_link=None, - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/some/good/page", - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.org/some/also/good/page", - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://bad.example.org/some/bad/page", - expect_code=400, - ) - - @override_config({"next_link_domain_whitelist": []}) - def test_empty_next_link_domain_whitelist(self): - """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially - disallowed - """ - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/a/page", - expect_code=400, - ) - - def _request_token( - self, - email: str, - client_secret: str, - next_link: Optional[str] = None, - expect_code: int = 200, - ) -> str: - """Request a validation token to add an email address to a user's account - - Args: - email: The email address to validate - client_secret: A secret string - next_link: A link to redirect the user to after validation - expect_code: Expected return code of the call - - Returns: - The ID of the new threepid validation session - """ - body = {"client_secret": client_secret, "email": email, "send_attempt": 1} - if next_link: - body["next_link"] = next_link - - channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - body, - ) - - if channel.code != expect_code: - raise HttpResponseException( - channel.code, - channel.result["reason"], - channel.result["body"], - ) - - return channel.json_body.get("sid") - - def _request_token_invalid_email( - self, - email, - expected_errcode, - expected_error, - client_secret="foobar", - ): - channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(expected_errcode, channel.json_body["errcode"]) - self.assertEqual(expected_error, channel.json_body["error"]) - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - channel = self.make_request("GET", path, shorthand=False) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) - - def _add_email(self, request_email, expected_email): - """Test adding an email to profile""" - previous_email_attempts = len(self.email_attempts) - - client_secret = "foobar" - session_id = self._request_token(request_email, client_secret) - - self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) - link = self._get_link_from_email() - - self._validate_token(link) - - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - - threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} - self.assertIn(expected_email, threepids) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py deleted file mode 100644 index cf5cfb910c..0000000000 --- a/tests/rest/client/v2_alpha/test_auth.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright 2018 New Vector -# Copyright 2020-2021 The Matrix.org Foundation C.I.C -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional, Union - -from twisted.internet.defer import succeed - -import synapse.rest.admin -from synapse.api.constants import LoginType -from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client import account, auth, devices, login, register -from synapse.rest.synapse.client import build_synapse_client_resource_tree -from synapse.types import JsonDict, UserID - -from tests import unittest -from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.v1.utils import TEST_OIDC_CONFIG -from tests.server import FakeChannel -from tests.unittest import override_config, skip_unless - - -class DummyRecaptchaChecker(UserInteractiveAuthChecker): - def __init__(self, hs): - super().__init__(hs) - self.recaptcha_attempts = [] - - def check_auth(self, authdict, clientip): - self.recaptcha_attempts.append((authdict, clientip)) - return succeed(True) - - -class FallbackAuthTests(unittest.HomeserverTestCase): - - servlets = [ - auth.register_servlets, - register.register_servlets, - ] - hijack_auth = False - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - - config["enable_registration_captcha"] = True - config["recaptcha_public_key"] = "brokencake" - config["registrations_require_3pid"] = [] - - hs = self.setup_test_homeserver(config=config) - return hs - - def prepare(self, reactor, clock, hs): - self.recaptcha_checker = DummyRecaptchaChecker(hs) - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - - def register(self, expected_response: int, body: JsonDict) -> FakeChannel: - """Make a register request.""" - channel = self.make_request("POST", "register", body) - - self.assertEqual(channel.code, expected_response) - return channel - - def recaptcha( - self, - session: str, - expected_post_response: int, - post_session: Optional[str] = None, - ) -> None: - """Get and respond to a fallback recaptcha. Returns the second request.""" - if post_session is None: - post_session = session - - channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) - self.assertEqual(channel.code, 200) - - channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + post_session - + "&g-recaptcha-response=a", - ) - self.assertEqual(channel.code, expected_post_response) - - # The recaptcha handler is called with the response given - attempts = self.recaptcha_checker.recaptcha_attempts - self.assertEqual(len(attempts), 1) - self.assertEqual(attempts[0][0]["response"], "a") - - def test_fallback_captcha(self): - """Ensure that fallback auth via a captcha works.""" - # Returns a 401 as per the spec - channel = self.register( - 401, - {"username": "user", "type": "m.login.password", "password": "bar"}, - ) - - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) - - # Complete the recaptcha step. - self.recaptcha(session, 200) - - # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) - - # Now we should have fulfilled a complete auth flow, including - # the recaptcha fallback step, we can then send a - # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) - - # We're given a registered user. - self.assertEqual(channel.json_body["user_id"], "@user:test") - - def test_complete_operation_unknown_session(self): - """ - Attempting to mark an invalid session as complete should error. - """ - # Make the initial request to register. (Later on a different password - # will be used.) - # Returns a 401 as per the spec - channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} - ) - - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) - - # Attempt to complete the recaptcha step with an unknown session. - # This results in an error. - self.recaptcha(session, 400, session + "unknown") - - -class UIAuthTests(unittest.HomeserverTestCase): - servlets = [ - auth.register_servlets, - devices.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - ] - - def default_config(self): - config = super().default_config() - - # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns - # False, so synapse will see the requested uri as http://..., so using http in - # the public_baseurl stops Synapse trying to redirect to https. - config["public_baseurl"] = "http://synapse.test" - - if HAS_OIDC: - # we enable OIDC as a way of testing SSO flows - oidc_config = {} - oidc_config.update(TEST_OIDC_CONFIG) - oidc_config["allow_existing_users"] = True - config["oidc_config"] = oidc_config - - return config - - def create_resource_dict(self): - resource_dict = super().create_resource_dict() - resource_dict.update(build_synapse_client_resource_tree(self.hs)) - return resource_dict - - def prepare(self, reactor, clock, hs): - self.user_pass = "pass" - self.user = self.register_user("test", self.user_pass) - self.device_id = "dev1" - self.user_tok = self.login("test", self.user_pass, self.device_id) - - def delete_device( - self, - access_token: str, - device: str, - expected_response: int, - body: Union[bytes, JsonDict] = b"", - ) -> FakeChannel: - """Delete an individual device.""" - channel = self.make_request( - "DELETE", - "devices/" + device, - body, - access_token=access_token, - ) - - # Ensure the response is sane. - self.assertEqual(channel.code, expected_response) - - return channel - - def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: - """Delete 1 or more devices.""" - # Note that this uses the delete_devices endpoint so that we can modify - # the payload half-way through some tests. - channel = self.make_request( - "POST", - "delete_devices", - body, - access_token=self.user_tok, - ) - - # Ensure the response is sane. - self.assertEqual(channel.code, expected_response) - - return channel - - def test_ui_auth(self): - """ - Test user interactive authentication outside of registration. - """ - # Attempt to delete this device. - # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow. - self.delete_device( - self.user_tok, - self.device_id, - 200, - { - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - def test_grandfathered_identifier(self): - """Check behaviour without "identifier" dict - - Synapse used to require clients to submit a "user" field for m.login.password - UIA - check that still works. - """ - - channel = self.delete_device(self.user_tok, self.device_id, 401) - session = channel.json_body["session"] - - # Make another request providing the UI auth flow. - self.delete_device( - self.user_tok, - self.device_id, - 200, - { - "auth": { - "type": "m.login.password", - "user": self.user, - "password": self.user_pass, - "session": session, - }, - }, - ) - - def test_can_change_body(self): - """ - The client dict can be modified during the user interactive authentication session. - - Note that it is not spec compliant to modify the client dict during a - user interactive authentication session, but many clients currently do. - - When Synapse is updated to be spec compliant, the call to re-use the - session ID should be rejected. - """ - # Create a second login. - self.login("test", self.user_pass, "dev2") - - # Attempt to delete the first device. - # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow, but try to delete the - # second device. - self.delete_devices( - 200, - { - "devices": ["dev2"], - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - def test_cannot_change_uri(self): - """ - The initial requested URI cannot be modified during the user interactive authentication session. - """ - # Create a second login. - self.login("test", self.user_pass, "dev2") - - # Attempt to delete the first device. - # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow, but try to delete the - # second device. This results in an error. - # - # This makes use of the fact that the device ID is embedded into the URL. - self.delete_device( - self.user_tok, - "dev2", - 403, - { - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) - def test_can_reuse_session(self): - """ - The session can be reused if configured. - - Compare to test_cannot_change_uri. - """ - # Create a second and third login. - self.login("test", self.user_pass, "dev2") - self.login("test", self.user_pass, "dev3") - - # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) - - # Move the clock forward past the validation timeout. - self.reactor.advance(6) - - # Deleting another devices throws the user into UI auth. - channel = self.delete_device(self.user_tok, "dev3", 401) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow. - self.delete_device( - self.user_tok, - "dev3", - 200, - { - "auth": { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": self.user}, - "password": self.user_pass, - "session": session, - }, - }, - ) - - # Make another request, but try to delete the first device. This works - # due to re-using the previous session. - # - # Note that *no auth* information is provided, not even a session iD! - self.delete_device(self.user_tok, self.device_id, 200) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_via_sso(self): - """Test a successful UI Auth flow via SSO - - This includes: - * hitting the UIA SSO redirect endpoint - * checking it serves a confirmation page which links to the OIDC provider - * calling back to the synapse oidc callback - * checking that the original operation succeeds - """ - - # log the user in - remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) - self.assertEqual(login_resp["user_id"], self.user) - - # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) - - # check that SSO is offered - flows = channel.json_body["flows"] - self.assertIn({"stages": ["m.login.sso"]}, flows) - - # run the UIA-via-SSO flow - session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id - ) - - # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) - - # and now the delete request should succeed. - self.delete_device( - self.user_tok, - self.device_id, - 200, - body={"auth": {"session": session_id}}, - ) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_does_not_offer_password_for_sso_user(self): - login_resp = self.helper.login_via_oidc("username") - user_tok = login_resp["access_token"] - device_id = login_resp["device_id"] - - # now call the device deletion API: we should get the option to auth with SSO - # and not password. - channel = self.delete_device(user_tok, device_id, 401) - - flows = channel.json_body["flows"] - self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) - - def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - self.assertEqual(flows, [{"stages": ["m.login.password"]}]) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_offers_both_flows_for_upgraded_user(self): - """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) - self.assertEqual(login_resp["user_id"], self.user) - - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - # we have no particular expectations of ordering here - self.assertIn({"stages": ["m.login.password"]}, flows) - self.assertIn({"stages": ["m.login.sso"]}, flows) - self.assertEqual(len(flows), 2) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_fails_for_incorrect_sso_user(self): - """If the user tries to authenticate with the wrong SSO user, they get an error""" - # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) - self.assertEqual(login_resp["user_id"], self.user) - - # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - self.assertIn({"stages": ["m.login.sso"]}, flows) - session_id = channel.json_body["session"] - - # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id - ) - - # that should return a failure message - self.assertSubstring("We were unable to validate", channel.text_body) - - # ... and the delete op should now fail with a 403 - self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} - ) - - -class RefreshAuthTests(unittest.HomeserverTestCase): - servlets = [ - auth.register_servlets, - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - ] - hijack_auth = False - - def prepare(self, reactor, clock, hs): - self.user_pass = "pass" - self.user = self.register_user("test", self.user_pass) - - def test_login_issue_refresh_token(self): - """ - A login response should include a refresh_token only if asked. - """ - # Test login - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - - login_without_refresh = self.make_request( - "POST", "/_matrix/client/r0/login", body - ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) - self.assertNotIn("refresh_token", login_without_refresh.json_body) - - login_with_refresh = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) - self.assertIn("refresh_token", login_with_refresh.json_body) - self.assertIn("expires_in_ms", login_with_refresh.json_body) - - def test_register_issue_refresh_token(self): - """ - A register response should include a refresh_token only if asked. - """ - register_without_refresh = self.make_request( - "POST", - "/_matrix/client/r0/register", - { - "username": "test2", - "password": self.user_pass, - "auth": {"type": LoginType.DUMMY}, - }, - ) - self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result - ) - self.assertNotIn("refresh_token", register_without_refresh.json_body) - - register_with_refresh = self.make_request( - "POST", - "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", - { - "username": "test3", - "password": self.user_pass, - "auth": {"type": LoginType.DUMMY}, - }, - ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) - self.assertIn("refresh_token", register_with_refresh.json_body) - self.assertIn("expires_in_ms", register_with_refresh.json_body) - - def test_token_refresh(self): - """ - A refresh token can be used to issue a new access token. - """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - - refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertIn("access_token", refresh_response.json_body) - self.assertIn("refresh_token", refresh_response.json_body) - self.assertIn("expires_in_ms", refresh_response.json_body) - - # The access and refresh tokens should be different from the original ones after refresh - self.assertNotEqual( - login_response.json_body["access_token"], - refresh_response.json_body["access_token"], - ) - self.assertNotEqual( - login_response.json_body["refresh_token"], - refresh_response.json_body["refresh_token"], - ) - - @override_config({"access_token_lifetime": "1m"}) - def test_refresh_token_expiration(self): - """ - The access token should have some time as specified in the config. - """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - self.assertApproximates( - login_response.json_body["expires_in_ms"], 60 * 1000, 100 - ) - - refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertApproximates( - refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 - ) - - def test_refresh_token_invalidation(self): - """Refresh tokens are invalidated after first use of the next token. - - A refresh token is considered invalid if: - - it was already used at least once - - and either - - the next access token was used - - the next refresh token was used - - The chain of tokens goes like this: - - login -|-> first_refresh -> third_refresh (fails) - |-> second_refresh -> fifth_refresh - |-> fourth_refresh (fails) - """ - - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - - # This first refresh should work properly - first_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result - ) - - # This one as well, since the token in the first one was never used - second_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result - ) - - # This one should not, since the token from the first refresh is not valid anymore - third_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": first_refresh_response.json_body["refresh_token"]}, - ) - self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result - ) - - # The associated access token should also be invalid - whoami_response = self.make_request( - "GET", - "/_matrix/client/r0/account/whoami", - access_token=first_refresh_response.json_body["access_token"], - ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) - - # But all other tokens should work (they will expire after some time) - for access_token in [ - second_refresh_response.json_body["access_token"], - login_response.json_body["access_token"], - ]: - whoami_response = self.make_request( - "GET", "/_matrix/client/r0/account/whoami", access_token=access_token - ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) - - # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail - fourth_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result - ) - - # But refreshing from the last valid refresh token still works - fifth_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": second_refresh_response.json_body["refresh_token"]}, - ) - self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result - ) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py deleted file mode 100644 index ac31e5ceaf..0000000000 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import synapse.rest.admin -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.rest.client import capabilities, login - -from tests import unittest -from tests.unittest import override_config - - -class CapabilitiesTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - capabilities.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.url = b"/_matrix/client/r0/capabilities" - hs = self.setup_test_homeserver() - self.config = hs.config - self.auth_handler = hs.get_auth_handler() - return hs - - def prepare(self, reactor, clock, hs): - self.localpart = "user" - self.password = "pass" - self.user = self.register_user(self.localpart, self.password) - - def test_check_auth_required(self): - channel = self.make_request("GET", self.url) - - self.assertEqual(channel.code, 401) - - def test_get_room_version_capabilities(self): - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - for room_version in capabilities["m.room_versions"]["available"].keys(): - self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) - - self.assertEqual( - self.config.default_room_version.identifier, - capabilities["m.room_versions"]["default"], - ) - - def test_get_change_password_capabilities_password_login(self): - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - - @override_config({"password_config": {"localdb_enabled": False}}) - def test_get_change_password_capabilities_localdb_disabled(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["m.change_password"]["enabled"]) - - @override_config({"password_config": {"enabled": False}}) - def test_get_change_password_capabilities_password_disabled(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["m.change_password"]["enabled"]) - - def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): - """Test that per default msc3283 is disabled server returns `m.change_password`.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) - self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) - self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) - - @override_config({"experimental_features": {"msc3283_enabled": True}}) - def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): - """Test if msc3283 is enabled server returns capabilities.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) - - @override_config( - { - "enable_set_displayname": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_get_set_displayname_capabilities_displayname_disabled(self): - """Test if set displayname is disabled that the server responds it.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - - @override_config( - { - "enable_set_avatar_url": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): - """Test if set avatar_url is disabled that the server responds it.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - - @override_config( - { - "enable_3pid_changes": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_change_3pid_capabilities_3pid_disabled(self): - """Test if change 3pid is disabled that the server responds it.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) - - def test_get_does_not_include_msc3244_fields_by_default(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - self.assertNotIn( - "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] - ) - - @override_config({"experimental_features": {"msc3244_enabled": True}}) - def test_get_does_include_msc3244_fields_when_enabled(self): - access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( - self.user, device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - for details in capabilities["m.room_versions"][ - "org.matrix.msc3244.room_capabilities" - ].values(): - if details["preferred"] is not None: - self.assertTrue( - details["preferred"] in KNOWN_ROOM_VERSIONS, - str(details["preferred"]), - ) - - self.assertGreater(len(details["support"]), 0) - for room_version in details["support"]: - self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py deleted file mode 100644 index 475c6bed3d..0000000000 --- a/tests/rest/client/v2_alpha/test_filter.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from synapse.api.errors import Codes -from synapse.rest.client import filter - -from tests import unittest - -PATH_PREFIX = "/_matrix/client/v2_alpha" - - -class FilterTestCase(unittest.HomeserverTestCase): - - user_id = "@apple:test" - hijack_auth = True - EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} - EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' - servlets = [filter.register_servlets] - - def prepare(self, reactor, clock, hs): - self.filtering = hs.get_filtering() - self.store = hs.get_datastore() - - def test_add_filter(self): - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % (self.user_id), - self.EXAMPLE_FILTER_JSON, - ) - - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.json_body, {"filter_id": "0"}) - filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) - self.pump() - self.assertEquals(filter.result, self.EXAMPLE_FILTER) - - def test_add_filter_for_other_user(self): - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), - self.EXAMPLE_FILTER_JSON, - ) - - self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - - def test_add_filter_non_local_user(self): - _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % (self.user_id), - self.EXAMPLE_FILTER_JSON, - ) - - self.hs.is_mine = _is_mine - self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - - def test_get_filter(self): - filter_id = defer.ensureDeferred( - self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER - ) - ) - self.reactor.advance(1) - filter_id = filter_id.result - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) - ) - - self.assertEqual(channel.result["code"], b"200") - self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) - - def test_get_filter_non_existant(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"404") - self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) - - # Currently invalid params do not have an appropriate errcode - # in errors.py - def test_get_filter_invalid_id(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"400") - - # No ID also returns an invalid_id error - def test_get_filter_no_id(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py deleted file mode 100644 index d7fa635eae..0000000000 --- a/tests/rest/client/v2_alpha/test_keys.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License - -from http import HTTPStatus - -from synapse.api.errors import Codes -from synapse.rest import admin -from synapse.rest.client import keys, login - -from tests import unittest - - -class KeyQueryTestCase(unittest.HomeserverTestCase): - servlets = [ - keys.register_servlets, - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def test_rejects_device_id_ice_key_outside_of_list(self): - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - bob = self.register_user("bob", "uncle") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - { - "device_keys": { - bob: "device_id1", - }, - }, - alice_token, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) - - def test_rejects_device_key_given_as_map_to_bool(self): - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - bob = self.register_user("bob", "uncle") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - { - "device_keys": { - bob: { - "device_id1": True, - }, - }, - }, - alice_token, - ) - - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) - - def test_requires_device_key(self): - """`device_keys` is required. We should complain if it's missing.""" - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - {}, - alice_token, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py deleted file mode 100644 index 3cf5871899..0000000000 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.api.constants import LoginType -from synapse.api.errors import Codes -from synapse.rest import admin -from synapse.rest.client import account, login, password_policy, register - -from tests import unittest - - -class PasswordPolicyTestCase(unittest.HomeserverTestCase): - """Tests the password policy feature and its compliance with MSC2000. - - When validating a password, Synapse does the necessary checks in this order: - - 1. Password is long enough - 2. Password contains digit(s) - 3. Password contains symbol(s) - 4. Password contains uppercase letter(s) - 5. Password contains lowercase letter(s) - - For each test below that checks whether a password triggers the right error code, - that test provides a password good enough to pass the previous tests, but not the - one it is currently testing (nor any test that comes afterward). - """ - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - register.register_servlets, - password_policy.register_servlets, - account.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.register_url = "/_matrix/client/r0/register" - self.policy = { - "enabled": True, - "minimum_length": 10, - "require_digit": True, - "require_symbol": True, - "require_lowercase": True, - "require_uppercase": True, - } - - config = self.default_config() - config["password_config"] = { - "policy": self.policy, - } - - hs = self.setup_test_homeserver(config=config) - return hs - - def test_get_policy(self): - """Tests if the /password_policy endpoint returns the configured policy.""" - - channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual( - channel.json_body, - { - "m.minimum_length": 10, - "m.require_digit": True, - "m.require_symbol": True, - "m.require_lowercase": True, - "m.require_uppercase": True, - }, - channel.result, - ) - - def test_password_too_short(self): - request_data = json.dumps({"username": "kermit", "password": "shorty"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_TOO_SHORT, - channel.result, - ) - - def test_password_no_digit(self): - request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_DIGIT, - channel.result, - ) - - def test_password_no_symbol(self): - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_SYMBOL, - channel.result, - ) - - def test_password_no_uppercase(self): - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_UPPERCASE, - channel.result, - ) - - def test_password_no_lowercase(self): - request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_LOWERCASE, - channel.result, - ) - - def test_password_compliant(self): - request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) - channel = self.make_request("POST", self.register_url, request_data) - - # Getting a 401 here means the password has passed validation and the server has - # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) - - def test_password_change(self): - """This doesn't test every possible use case, only that hitting /account/password - triggers the password validation code. - """ - compliant_password = "C0mpl!antpassword" - not_compliant_password = "notcompliantpassword" - - user_id = self.register_user("kermit", compliant_password) - tok = self.login("kermit", compliant_password) - - request_data = json.dumps( - { - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, - }, - } - ) - channel = self.make_request( - "POST", - "/_matrix/client/r0/account/password", - request_data, - access_token=tok, - ) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py deleted file mode 100644 index fecda037a5..0000000000 --- a/tests/rest/client/v2_alpha/test_register.py +++ /dev/null @@ -1,746 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import datetime -import json -import os - -import pkg_resources - -import synapse.rest.admin -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType -from synapse.api.errors import Codes -from synapse.appservice import ApplicationService -from synapse.rest.client import account, account_validity, login, logout, register, sync - -from tests import unittest -from tests.unittest import override_config - - -class RegisterRestServletTestCase(unittest.HomeserverTestCase): - - servlets = [ - login.register_servlets, - register.register_servlets, - synapse.rest.admin.register_servlets, - ] - url = b"/_matrix/client/r0/register" - - def default_config(self): - config = super().default_config() - config["allow_guest_access"] = True - return config - - def test_POST_appservice_registration_valid(self): - user_id = "@as_user_kermit:test" - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - sender="@as:test", - ) - - self.hs.get_datastore().services_cache.append(appservice) - request_data = json.dumps( - {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) - - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - det_data = {"user_id": user_id, "home_server": self.hs.hostname} - self.assertDictContainsSubset(det_data, channel.json_body) - - def test_POST_appservice_registration_no_type(self): - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - sender="@as:test", - ) - - self.hs.get_datastore().services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) - - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"400", channel.result) - - def test_POST_appservice_registration_invalid(self): - self.appservice = None # no application service exists - request_data = json.dumps( - {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"401", channel.result) - - def test_POST_bad_password(self): - request_data = json.dumps({"username": "kermit", "password": 666}) - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid password") - - def test_POST_bad_username(self): - request_data = json.dumps({"username": 777, "password": "monkey"}) - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") - - def test_POST_user_valid(self): - user_id = "@kermit:test" - device_id = "frogfone" - params = { - "username": "kermit", - "password": "monkey", - "device_id": device_id, - "auth": {"type": LoginType.DUMMY}, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - - det_data = { - "user_id": user_id, - "home_server": self.hs.hostname, - "device_id": device_id, - } - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) - - @override_config({"enable_registration": False}) - def test_POST_disabled_registration(self): - request_data = json.dumps({"username": "kermit", "password": "monkey"}) - self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Registration has been disabled") - self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") - - def test_POST_guest_registration(self): - self.hs.config.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) - - def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Guest access is disabled") - - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting_guest(self): - for i in range(0, 6): - url = self.url + b"?kind=guest" - channel = self.make_request(b"POST", url, b"{}") - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting(self): - for i in range(0, 6): - params = { - "username": "kermit" + str(i), - "password": "monkey", - "device_id": "frogfone", - "auth": {"type": LoginType.DUMMY}, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_advertised_flows(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - # with the stock config, we only expect the dummy flow - self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "enable_registration_captcha": True, - "user_consent": { - "version": "1", - "template_dir": "/", - "require_at_registration": True, - }, - "account_threepid_delegates": { - "email": "https://id_server", - "msisdn": "https://id_server", - }, - } - ) - def test_advertised_flows_captcha_and_terms_and_3pids(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - self.assertCountEqual( - [ - ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], - ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], - ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], - [ - "m.login.recaptcha", - "m.login.terms", - "m.login.msisdn", - "m.login.email.identity", - ], - ], - (f["stages"] for f in flows), - ) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "registrations_require_3pid": ["email"], - "disable_msisdn_registration": True, - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_advertised_flows_no_msisdn_email_required(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - # with the stock config, we expect all four combinations of 3pid - self.assertCountEqual( - [["m.login.email.identity"]], (f["stages"] for f in flows) - ) - - @unittest.override_config( - { - "request_token_inhibit_3pid_errors": True, - "public_baseurl": "https://test_server", - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_request_token_existing_email_inhibit_error(self): - """Test that requesting a token via this endpoint doesn't leak existing - associations if configured that way. - """ - user_id = self.register_user("kermit", "monkey") - self.login("kermit", "monkey") - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.hs.get_datastore().user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": email, "send_attempt": 1}, - ) - self.assertEquals(200, channel.code, channel.result) - - self.assertIsNotNone(channel.json_body.get("sid")) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_reject_invalid_email(self): - """Check that bad emails are rejected""" - - # Test for email with multiple @ - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - # Check error to ensure that we're not erroring due to a bug in the test. - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - # Test for email with no @ - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": "email", "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - # Test for super long email - email = "a@" + "a" * 1000 - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": email, "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - -class AccountValidityTestCase(unittest.HomeserverTestCase): - - servlets = [ - register.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - sync.register_servlets, - logout.register_servlets, - account_validity.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - # Test for account expiring after a week. - config["enable_registration"] = True - config["account_validity"] = { - "enabled": True, - "period": 604800000, # Time in ms for 1 week - } - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def test_validity_period(self): - self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. - channel = self.make_request(b"GET", "/sync", access_token=tok) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - - channel = self.make_request(b"GET", "/sync", access_token=tok) - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result - ) - - def test_manual_renewal(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - - # If we register the admin user at the beginning of the test, it will - # expire at the same time as the normal user and the renewal request - # will be denied. - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = {"user_id": user_id} - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_manual_expire(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = { - "user_id": user_id, - "expiration_ts": 0, - "enable_renewal_emails": False, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result - ) - - def test_logging_out_expired_user(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = { - "user_id": user_id, - "expiration_ts": 0, - "enable_renewal_emails": False, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Try to log the user out - channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Log the user in again (allowed for expired accounts) - tok = self.login("kermit", "monkey") - - # Try to log out all of the user's sessions - channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - -class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - - servlets = [ - register.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - sync.register_servlets, - account_validity.register_servlets, - account.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Test for account expiring after a week and renewal emails being sent 2 - # days before expiry. - config["enable_registration"] = True - config["account_validity"] = { - "enabled": True, - "period": 604800000, # Time in ms for 1 week - "renew_at": 172800000, # Time in ms for 2 days - "renew_by_email_enabled": True, - "renew_email_subject": "Renew your account", - "account_renewed_html_path": "account_renewed.html", - "invalid_token_html_path": "invalid_token.html", - } - - # Email config. - - config["email"] = { - "enable_notifs": True, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "expiry_template_html": "notice_expiry.html", - "expiry_template_text": "notice_expiry.txt", - "notif_template_html": "notif_mail.html", - "notif_template_text": "notif_mail.txt", - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "aaa" - - self.hs = self.setup_test_homeserver(config=config) - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) - - self.email_attempts = [] - self.hs.get_send_email_handler()._sendmail = sendmail - - self.store = self.hs.get_datastore() - - return self.hs - - def test_renewal_email(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - - # Move 5 days forward. This should trigger a renewal email to be sent. - self.reactor.advance(datetime.timedelta(days=5).total_seconds()) - self.assertEqual(len(self.email_attempts), 1) - - # Retrieving the URL from the email is too much pain for now, so we - # retrieve the token from the DB. - renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) - url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect on a successful renewal. - expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( - expiration_ts=expiration_ts - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - # Move 1 day forward. Try to renew with the same token again. - url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect when reusing a - # token. The account expiration date should not have changed. - expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( - expiration_ts=expiration_ts - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - # Move 3 days forward. If the renewal failed, every authed request with - # our access token should be denied from now, otherwise they should - # succeed. - self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_renewal_invalid_token(self): - # Hit the renewal endpoint with an invalid token and check that it behaves as - # expected, i.e. that it responds with 404 Not Found and the correct HTML. - url = "/_matrix/client/unstable/account_validity/renew?token=123" - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"404", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect when using an - # invalid/unknown token. - expected_html = ( - self.hs.config.account_validity.account_validity_invalid_token_template.render() - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - def test_manual_email_send(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - channel = self.make_request( - b"POST", - "/_matrix/client/unstable/account_validity/send_mail", - access_token=tok, - ) - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.assertEqual(len(self.email_attempts), 1) - - def test_deactivated_user(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - } - ) - channel = self.make_request( - "POST", "account/deactivate", request_data, access_token=tok - ) - self.assertEqual(channel.code, 200) - - self.reactor.advance(datetime.timedelta(days=8).total_seconds()) - - self.assertEqual(len(self.email_attempts), 0) - - def create_user(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.get_clock().time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) - return user_id, tok - - def test_manual_email_send_expired_account(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.get_clock().time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) - - # Make the account expire. - self.reactor.advance(datetime.timedelta(days=8).total_seconds()) - - # Ignore all emails sent by the automatic background task and only focus on the - # ones sent manually. - self.email_attempts = [] - - # Test that we're still able to manually trigger a mail to be sent. - channel = self.make_request( - b"POST", - "/_matrix/client/unstable/account_validity/send_mail", - access_token=tok, - ) - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.assertEqual(len(self.email_attempts), 1) - - -class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - - def make_homeserver(self, reactor, clock): - self.validity_period = 10 - self.max_delta = self.validity_period * 10.0 / 100.0 - - config = self.default_config() - - config["enable_registration"] = True - config["account_validity"] = {"enabled": False} - - self.hs = self.setup_test_homeserver(config=config) - - # We need to set these directly, instead of in the homeserver config dict above. - # This is due to account validity-related config options not being read by - # Synapse when account_validity.enabled is False. - self.hs.get_datastore()._account_validity_period = self.validity_period - self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta - - self.store = self.hs.get_datastore() - - return self.hs - - def test_background_job(self): - """ - Tests the same thing as test_background_job, except that it sets the - startup_job_max_delta parameter and checks that the expiration date is within the - allowed range. - """ - user_id = self.register_user("kermit_delta", "user") - - self.hs.config.account_validity.startup_job_max_delta = self.max_delta - - now_ms = self.hs.get_clock().time_msec() - self.get_success(self.store._set_expiration_date_when_missing()) - - res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - - self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) - self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py deleted file mode 100644 index 02b5e9a8d0..0000000000 --- a/tests/rest/client/v2_alpha/test_relations.py +++ /dev/null @@ -1,724 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import json -import urllib -from typing import Optional - -from synapse.api.constants import EventTypes, RelationTypes -from synapse.rest import admin -from synapse.rest.client import login, register, relations, room - -from tests import unittest - - -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - relations.register_servlets, - room.register_servlets, - login.register_servlets, - register.register_servlets, - admin.register_servlets_for_client_rest_resource, - ] - hijack_auth = False - - def make_homeserver(self, reactor, clock): - # We need to enable msc1849 support for aggregations - config = self.default_config() - config["experimental_msc1849_support_enabled"] = True - - # We enable frozen dicts as relations/edits change event contents, so we - # want to test that we don't modify the events in the caches. - config["use_frozen_dicts"] = True - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor, clock, hs): - self.user_id, self.user_token = self._create_user("alice") - self.user2_id, self.user2_token = self._create_user("bob") - - self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) - self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) - res = self.helper.send(self.room, body="Hi!", tok=self.user_token) - self.parent_id = res["event_id"] - - def test_send_relation(self): - """Tests that sending a relation using the new /send_relation works - creates the right shape of event. - """ - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) - - event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assert_dict( - { - "type": "m.reaction", - "sender": self.user_id, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "key": "👍", - "rel_type": RelationTypes.ANNOTATION, - } - }, - }, - channel.json_body, - ) - - def test_deny_membership(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) - self.assertEquals(400, channel.code, channel.json_body) - - def test_deny_double_react(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(400, channel.code, channel.json_body) - - def test_basic_paginate_relations(self): - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the full - # relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - - # Make sure next_batch has something in it that looks like it could be a - # valid token. - self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) - - def test_repeated_paginate_relations(self): - """Test that if we paginate using a limit and tokens then we get the - expected events. - """ - - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - prev_token = None - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) - - def test_aggregation_pagination_groups(self): - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEquals(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token = None - found_groups = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEquals(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self): - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - self.assertEquals(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) - - prev_token = None - found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s" - "/aggregations/%s/%s/m.reaction/%s?limit=1%s" - % ( - self.room, - self.parent_id, - RelationTypes.ANNOTATION, - encoded_key, - from_token, - ), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) - - def test_aggregation(self): - """Test that annotations get correctly aggregated.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_redactions(self): - """Test that annotations get correctly aggregated after a redaction.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Now lets redact one of the 'a' reactions - channel = self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), - access_token=self.user_token, - content={}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - - def test_aggregation_must_be_annotation(self): - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" - % (self.room, self.parent_id, RelationTypes.REPLACE), - access_token=self.user_token, - ) - self.assertEquals(400, channel.code, channel.json_body) - - def test_aggregation_get_event(self): - """Test that annotations and references get correctly bundled when - getting the parent event. - """ - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - reply_2 = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - }, - ) - - def test_edit(self): - """Test that a simple edit works.""" - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals(channel.json_body["content"], new_body) - - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_multi_edit(self): - """Test that multiple edits, including attempts by people who - shouldn't be allowed, are correctly handled. - """ - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message.WRONG_TYPE", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals(channel.json_body["content"], new_body) - - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_edit_reply(self): - """Test that editing a reply works.""" - - # Create a reply to edit. - channel = self._send_relation( - RelationTypes.REFERENCE, - "m.room.message", - content={"msgtype": "m.text", "body": "A reply!"}, - ) - self.assertEquals(200, channel.code, channel.json_body) - reply = channel.json_body["event_id"] - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=reply, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, reply), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # We expect to see the new body in the dict, as well as the reference - # metadata sill intact. - self.assertDictContainsSubset(new_body, channel.json_body["content"]) - self.assertDictContainsSubset( - { - "m.relates_to": { - "event_id": self.parent_id, - "key": None, - "rel_type": "m.reference", - } - }, - channel.json_body["content"], - ) - - # We expect that the edit relation appears in the unsigned relations - # section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_relations_redaction_redacts_edits(self): - """Test that edits of an event are redacted when the original event - is redacted. - """ - # Send a new event - res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - parent_id=original_event_id, - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check the relation is returned - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEquals(len(channel.json_body["chunk"]), 1) - - # Redact the original event - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), - access_token=self.user_token, - content="{}", - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) - - def test_aggregations_redaction_prevents_access_to_aggregations(self): - """Test that annotations of an event are redacted when the original event - is redacted. - """ - # Send a new event - res = self.helper.send(self.room, body="Hello!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Redact the original - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % ( - self.room, - original_event_id, - "test_aggregations_redaction_prevents_access_to_aggregations", - ), - access_token=self.user_token, - content="{}", - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check that aggregations returns zero - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) - - def _send_relation( - self, - relation_type, - event_type, - key=None, - content: Optional[dict] = None, - access_token=None, - parent_id=None, - ): - """Helper function to send a relation pointing at `self.parent_id` - - Args: - relation_type (str): One of `RelationTypes` - event_type (str): The type of the event to create - parent_id (str): The event_id this relation relates to. If None, then self.parent_id - key (str|None): The aggregation key used for m.annotation relation - type. - content(dict|None): The content of the created event. - access_token (str|None): The access token used to send the relation, - defaults to `self.user_token` - - Returns: - FakeChannel - """ - if not access_token: - access_token = self.user_token - - query = "" - if key: - query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) - - original_id = parent_id if parent_id else self.parent_id - - channel = self.make_request( - "POST", - "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" - % (self.room, original_id, relation_type, event_type, query), - json.dumps(content or {}).encode("utf-8"), - access_token=access_token, - ) - return channel - - def _create_user(self, localpart): - user_id = self.register_user(localpart, "abc123") - access_token = self.login(localpart, "abc123") - - return user_id, access_token diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py deleted file mode 100644 index ee6b0b9ebf..0000000000 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2021 Callum Brown -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import synapse.rest.admin -from synapse.rest.client import login, report_event, room - -from tests import unittest - - -class ReportEventTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - report_event.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - self.other_user = self.register_user("user", "pass") - self.other_user_tok = self.login("user", "pass") - - self.room_id = self.helper.create_room_as( - self.other_user, tok=self.other_user_tok, is_public=True - ) - self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) - resp = self.helper.send(self.room_id, tok=self.admin_user_tok) - self.event_id = resp["event_id"] - self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" - - def test_reason_str_and_score_int(self): - data = {"reason": "this makes me sad", "score": -100} - self._assert_status(200, data) - - def test_no_reason(self): - data = {"score": 0} - self._assert_status(200, data) - - def test_no_score(self): - data = {"reason": "this makes me sad"} - self._assert_status(200, data) - - def test_no_reason_and_no_score(self): - data = {} - self._assert_status(200, data) - - def test_reason_int_and_score_str(self): - data = {"reason": 10, "score": "string"} - self._assert_status(400, data) - - def test_reason_zero_and_score_blank(self): - data = {"reason": 0, "score": ""} - self._assert_status(400, data) - - def test_reason_and_score_null(self): - data = {"reason": None, "score": None} - self._assert_status(400, data) - - def _assert_status(self, response_status, data): - channel = self.make_request( - "POST", - self.report_path, - json.dumps(data), - access_token=self.other_user_tok, - ) - self.assertEqual( - response_status, int(channel.result["code"]), msg=channel.result["body"] - ) diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/v2_alpha/test_sendtodevice.py deleted file mode 100644 index 6db7062a8e..0000000000 --- a/tests/rest/client/v2_alpha/test_sendtodevice.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest import admin -from synapse.rest.client import login, sendtodevice, sync - -from tests.unittest import HomeserverTestCase, override_config - - -class SendToDeviceTestCase(HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - sendtodevice.register_servlets, - sync.register_servlets, - ] - - def test_user_to_user(self): - """A to-device message from one user to another should get delivered""" - - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send the message - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # check it appears - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - expected_result = { - "events": [ - { - "sender": user1, - "type": "m.test", - "content": test_msg, - } - ] - } - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should re-appear if we do another sync - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should *not* appear if we do an incremental sync - sync_token = channel.json_body["next_batch"] - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self): - """m.room_key_request has special-casing; test from local user""" - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send three messages - for i in range(3): - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", - content={"messages": {user2: {"d2": {"idx": i}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # now sync: we should get two of the three - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.room_key_request/3", - content={"messages": {user2: {"d2": {"idx": 3}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # ... which should arrive - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, - ) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self): - """m.room_key_request has special-casing; test from remote user""" - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - federation_registry = self.hs.get_federation_registry() - - # send three messages - for i in range(3): - self.get_success( - federation_registry.on_edu( - "m.direct_to_device", - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": i}}}, - "message_id": f"{i}", - }, - ) - ) - - # now sync: we should get two of the three - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": i}, - }, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - self.get_success( - federation_registry.on_edu( - "m.direct_to_device", - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": 3}}}, - "message_id": "3", - }, - ) - ) - - # ... which should arrive - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": 3}, - }, - ) diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py deleted file mode 100644 index 283eccd53f..0000000000 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import synapse.rest.admin -from synapse.rest.client import login, room, shared_rooms - -from tests import unittest -from tests.server import FakeChannel - - -class UserSharedRoomsTest(unittest.HomeserverTestCase): - """ - Tests the UserSharedRoomsServlet. - """ - - servlets = [ - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - shared_rooms.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.handler = hs.get_user_directory_handler() - - def _get_shared_rooms(self, token, other_user) -> FakeChannel: - return self.make_request( - "GET", - "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" - % other_user, - access_token=token, - ) - - def test_shared_room_list_public(self): - """ - A room should show up in the shared list of rooms between two users - if it is public. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - - def test_shared_room_list_private(self): - """ - A room should show up in the shared list of rooms between two users - if it is private. - """ - self._check_shared_rooms_with( - room_one_is_public=False, room_two_is_public=False - ) - - def test_shared_room_list_mixed(self): - """ - The shared room list between two users should contain both public and private - rooms. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) - - def _check_shared_rooms_with( - self, room_one_is_public: bool, room_two_is_public: bool - ): - """Checks that shared public or private rooms between two users appear in - their shared room lists - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - # Create a room. user1 invites user2, who joins - room_id_one = self.helper.create_room_as( - u1, is_public=room_one_is_public, tok=u1_token - ) - self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_one, user=u2, tok=u2_token) - - # Check shared rooms from user1's perspective. - # We should see the one room in common - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room_id_one) - - # Create another room and invite user2 to it - room_id_two = self.helper.create_room_as( - u1, is_public=room_two_is_public, tok=u1_token - ) - self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_two, user=u2, tok=u2_token) - - # Check shared rooms again. We should now see both rooms. - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 2) - for room_id_id in channel.json_body["joined"]: - self.assertIn(room_id_id, [room_id_one, room_id_two]) - - def test_shared_room_list_after_leave(self): - """ - A room should no longer be considered shared if the other - user has left it. - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - # Assert user directory is not empty - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) - - self.helper.leave(room, user=u1, tok=u1_token) - - # Check user1's view of shared rooms with user2 - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) - - # Check user2's view of shared rooms with user1 - channel = self._get_shared_rooms(u2_token, u1) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py deleted file mode 100644 index 95be369d4b..0000000000 --- a/tests/rest/client/v2_alpha/test_sync.py +++ /dev/null @@ -1,686 +0,0 @@ -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json - -import synapse.rest.admin -from synapse.api.constants import ( - EventContentFields, - EventTypes, - ReadReceiptEventFields, - RelationTypes, -) -from synapse.rest.client import knock, login, read_marker, receipts, room, sync - -from tests import unittest -from tests.federation.transport.test_knocking import ( - KnockingStrippedStateEventHelperMixin, -) -from tests.server import TimedOutException -from tests.unittest import override_config - - -class FilterTestCase(unittest.HomeserverTestCase): - - user_id = "@apple:test" - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_sync_argless(self): - channel = self.make_request("GET", "/sync") - - self.assertEqual(channel.code, 200) - self.assertIn("next_batch", channel.json_body) - - -class SyncFilterTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_sync_filter_labels(self): - """Test that we can filter by a label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_sync_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] - ) - - def test_sync_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_sync_filter_labels(self, sync_filter): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - - room_id = self.helper.create_room_as(user_id, tok=tok) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, - ) - - channel = self.make_request( - "GET", "/sync?filter=%s" % sync_filter, access_token=tok - ) - self.assertEqual(channel.code, 200, channel.result) - - return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] - - -class SyncTypingTests(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - user_id = True - hijack_auth = False - - def test_sync_backwards_typing(self): - """ - If the typing serial goes backwards and the typing handler is then reset - (such as when the master restarts and sets the typing serial to 0), we - do not incorrectly return typing information that had a serial greater - than the now-reset serial. - """ - typing_url = "/rooms/%s/typing/%s?access_token=%s" - sync_url = "/sync?timeout=3000000&access_token=%s&since=%s" - - # Register the user who gets notified - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - - # Register the user who sends the message - other_user_id = self.register_user("otheruser", "pass") - other_access_token = self.login("otheruser", "pass") - - # Create a room - room = self.helper.create_room_as(user_id, tok=access_token) - - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) - - # The other user joins - self.helper.join(room=room, user=other_user_id, tok=other_access_token) - - # The other user sends some messages - self.helper.send(room, body="Hi!", tok=other_access_token) - self.helper.send(room, body="There!", tok=other_access_token) - - # Start typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Stop typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": false}', - ) - self.assertEquals(200, channel.code) - - # Start typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - # Should return immediately - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Reset typing serial back to 0, as if the master had. - typing = self.hs.get_typing_handler() - typing._latest_room_serial = 0 - - # Since it checks the state token, we need some state to update to - # invalidate the stream token. - self.helper.send(room, body="There!", tok=other_access_token) - - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # This should time out! But it does not, because our stream token is - # ahead, and therefore it's saying the typing (that we've actually - # already seen) is new, since it's got a token above our new, now-reset - # stream token. - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Clear the typing information, so that it doesn't think everything is - # in the future. - typing._reset() - - # Now it SHOULD fail as it never completes! - with self.assertRaises(TimedOutException): - self.make_request("GET", sync_url % (access_token, next_batch)) - - -class SyncKnockTestCase( - unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin -): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - knock.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to create the room to knock on). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll knock on. - self.room_id = self.helper.create_room_as( - self.user_id, - is_public=False, - room_version="7", - tok=self.tok, - ) - - # Register the second user (used to knock on the room). - self.knocker = self.register_user("knocker", "monkey") - self.knocker_tok = self.login("knocker", "monkey") - - # Perform an initial sync for the knocking user. - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - # Set up some room state to test with. - self.expected_room_state = self.send_example_state_events_to_room( - hs, self.room_id, self.user_id - ) - - @override_config({"experimental_features": {"msc2403_enabled": True}}) - def test_knock_room_state(self): - """Tests that /sync returns state from a room after knocking on it.""" - # Knock on a room - channel = self.make_request( - "POST", - "/_matrix/client/r0/knock/%s" % (self.room_id,), - b"{}", - self.knocker_tok, - ) - self.assertEquals(200, channel.code, channel.result) - - # We expect to see the knock event in the stripped room state later - self.expected_room_state[EventTypes.Member] = { - "content": {"membership": "knock", "displayname": "knocker"}, - "state_key": "@knocker:test", - } - - # Check that /sync includes stripped state from the room - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.knocker_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Extract the stripped room state events from /sync - knock_entry = channel.json_body["rooms"]["knock"] - room_state_events = knock_entry[self.room_id]["knock_state"]["events"] - - # Validate that the knock membership event came last - self.assertEqual(room_state_events[-1]["type"], EventTypes.Member) - - # Validate the stripped room state events - self.check_knock_room_state_against_room_state( - room_state_events, self.expected_room_state - ) - - -class ReadReceiptsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - receipts.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Join the second user - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self): - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt to tell the server the first user's message was read - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that the first user can't see the other user's hidden read receipt - self.assertEqual(self._get_read_receipt(), None) - - def test_read_receipt_with_empty_body(self): - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt for this message with an empty body - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - def _get_read_receipt(self): - """Syncs and returns the read receipt.""" - - # Checks if event is a read receipt - def is_read_receipt(event): - return event["type"] == "m.receipt" - - # Sync - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - # Return the read receipt - ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ - "ephemeral" - ]["events"] - return next(filter(is_read_receipt, ephemeral_events), None) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - receipts.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. - self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that hidden read receipts don't break unread counts - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. - self.helper.send_state( - self.room_id, - "m.room.name", - {"name": "my super room"}, - tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, - "m.room.topic", - {"topic": "welcome!!!"}, - tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, - "org.matrix.custom_type", - {"body": "hello"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: int): - """Syncs and compares the unread count with the expected value.""" - - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], - expected_count, - room_entry, - ) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - -class SyncCacheTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_noop_sync_does_not_tightloop(self): - """If the sync times out, we shouldn't cache the result - - Essentially a regression test for #8518. - """ - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # we should immediately get an initial sync response - channel = self.make_request("GET", "/sync", access_token=self.tok) - self.assertEqual(channel.code, 200, channel.json_body) - - # now, make an incremental sync request, with a timeout - next_batch = channel.json_body["next_batch"] - channel = self.make_request( - "GET", - f"/sync?since={next_batch}&timeout=10000", - access_token=self.tok, - await_result=False, - ) - # that should block for 10 seconds - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=9900) - channel.await_result(timeout_ms=200) - self.assertEqual(channel.code, 200, channel.json_body) - - # we expect the next_batch in the result to be the same as before - self.assertEqual(channel.json_body["next_batch"], next_batch) - - # another incremental sync should also block. - channel = self.make_request( - "GET", - f"/sync?since={next_batch}&timeout=10000", - access_token=self.tok, - await_result=False, - ) - # that should block for 10 seconds - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=9900) - channel.await_result(timeout_ms=200) - self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py deleted file mode 100644 index 72f976d8e2..0000000000 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.rest import admin -from synapse.rest.client import login, room, room_upgrade_rest_servlet - -from tests import unittest -from tests.server import FakeChannel - - -class UpgradeRoomTest(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - room_upgrade_rest_servlet.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.handler = hs.get_user_directory_handler() - - self.creator = self.register_user("creator", "pass") - self.creator_token = self.login(self.creator, "pass") - - self.other = self.register_user("user", "pass") - self.other_token = self.login(self.other, "pass") - - self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) - self.helper.join(self.room_id, self.other, tok=self.other_token) - - def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: - # We never want a cached response. - self.reactor.advance(5 * 60 + 1) - - return self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, - # This will upgrade a room to the same version, but that's fine. - content={"new_version": DEFAULT_ROOM_VERSION}, - access_token=token or self.creator_token, - ) - - def test_upgrade(self): - """ - Upgrading a room should work fine. - """ - channel = self._upgrade_room() - self.assertEquals(200, channel.code, channel.result) - self.assertIn("replacement_room", channel.json_body) - - def test_not_in_room(self): - """ - Upgrading a room should work fine. - """ - # THe user isn't in the room. - roomless = self.register_user("roomless", "pass") - roomless_token = self.login(roomless, "pass") - - channel = self._upgrade_room(roomless_token) - self.assertEquals(403, channel.code, channel.result) - - def test_power_levels(self): - """ - Another user can upgrade the room if their power level is increased. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["users"][self.other] = 100 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - def test_power_levels_user_default(self): - """ - Another user can upgrade the room if the default power level for users is increased. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["users_default"] = 100 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - def test_power_levels_tombstone(self): - """ - Another user can upgrade the room if they can send the tombstone event. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["events"]["m.room.tombstone"] = 0 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/unittest.py b/tests/unittest.py index 3eec9c4d5b..f2c90cc47b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase): reactor=self.reactor, ) - from tests.rest.client.v1.utils import RestHelper + from tests.rest.client.utils import RestHelper self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) -- cgit 1.5.1 From 947dbbdfd1e0029da66f956d277b7c089928e1e7 Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Sat, 21 Aug 2021 22:14:43 +0100 Subject: Implement MSC3231: Token authenticated registration (#10142) Signed-off-by: Callum Brown This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231). --- changelog.d/10142.feature | 1 + docs/SUMMARY.md | 1 + docs/sample_config.yaml | 15 + .../admin_api/registration_tokens.md | 295 +++++++++ docs/workers.md | 1 + synapse/api/constants.py | 1 + synapse/app/generic_worker.py | 6 +- synapse/config/ratelimiting.py | 11 + synapse/config/registration.py | 15 + synapse/handlers/ui_auth/__init__.py | 5 + synapse/handlers/ui_auth/checkers.py | 65 ++ synapse/res/templates/registration_token.html | 23 + synapse/rest/admin/__init__.py | 8 + synapse/rest/admin/registration_tokens.py | 321 ++++++++++ synapse/rest/client/auth.py | 24 + synapse/rest/client/register.py | 72 +++ synapse/storage/databases/main/registration.py | 316 +++++++++ synapse/storage/databases/main/ui_auth.py | 43 ++ .../main/delta/63/01create_registration_tokens.sql | 23 + tests/rest/admin/test_registration_tokens.py | 710 +++++++++++++++++++++ tests/rest/client/test_register.py | 434 +++++++++++++ 21 files changed, 2389 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10142.feature create mode 100644 docs/usage/administration/admin_api/registration_tokens.md create mode 100644 synapse/res/templates/registration_token.html create mode 100644 synapse/rest/admin/registration_tokens.py create mode 100644 synapse/storage/schema/main/delta/63/01create_registration_tokens.sql create mode 100644 tests/rest/admin/test_registration_tokens.py (limited to 'tests') diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature new file mode 100644 index 0000000000..5353f6269d --- /dev/null +++ b/changelog.d/10142.feature @@ -0,0 +1 @@ +Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 634bc833ab..4fcd2b7852 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -53,6 +53,7 @@ - [Media](admin_api/media_admin_api.md) - [Purge History](admin_api/purge_history_api.md) - [Register Users](admin_api/register_api.md) + - [Registration Tokens](usage/administration/admin_api/registration_tokens.md) - [Manipulate Room Membership](admin_api/room_membership.md) - [Rooms](admin_api/rooms.md) - [Server Notices](admin_api/server_notices.md) diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2b0c453242..935841dbfa 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config" # is using # - one for registration that ratelimits registration requests based on the # client's IP address. +# - one for checking the validity of registration tokens that ratelimits +# requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # per_second: 0.17 # burst_count: 3 # +#rc_registration_token_validity: +# per_second: 0.1 +# burst_count: 5 +# #rc_login: # address: # per_second: 0.17 @@ -1169,6 +1175,15 @@ url_preview_accept_language: # #enable_3pid_lookup: true +# Require users to submit a token during registration. +# Tokens can be managed using the admin API: +# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html +# Note that `enable_registration` must be set to `true`. +# Disabling this option will not delete any tokens previously generated. +# Defaults to false. Uncomment the following to require tokens: +# +#registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md new file mode 100644 index 0000000000..828c0277d6 --- /dev/null +++ b/docs/usage/administration/admin_api/registration_tokens.md @@ -0,0 +1,295 @@ +# Registration Tokens + +This API allows you to manage tokens which can be used to authenticate +registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md). +To use it, you will need to enable the `registration_requires_token` config +option, and authenticate by providing an `access_token` for a server admin: +see [Admin API](../../usage/administration/admin_api). +Note that this API is still experimental; not all clients may support it yet. + + +## Registration token objects + +Most endpoints make use of JSON objects that contain details about tokens. +These objects have the following fields: +- `token`: The token which can be used to authenticate registration. +- `uses_allowed`: The number of times the token can be used to complete a + registration before it becomes invalid. +- `pending`: The number of pending uses the token has. When someone uses + the token to authenticate themselves, the pending counter is incremented + so that the token is not used more than the permitted number of times. + When the person completes registration the pending counter is decremented, + and the completed counter is incremented. +- `completed`: The number of times the token has been used to successfully + complete a registration. +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + To convert this into a human-readable form you can remove the milliseconds + and use the `date` command. For example, `date -d '@1625394937'`. + + +## List all tokens + +Lists all tokens and details about them. If the request is successful, the top +level JSON object will have a `registration_tokens` key which is an array of +registration token objects. + +``` +GET /_synapse/admin/v1/registration_tokens +``` + +Optional query parameters: +- `valid`: `true` or `false`. If `true`, only valid tokens are returned. + If `false`, only tokens that have expired or have had all uses exhausted are + returned. If omitted, all tokens are returned regardless of validity. + +Example: + +``` +GET /_synapse/admin/v1/registration_tokens +``` +``` +200 OK + +{ + "registration_tokens": [ + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + }, + { + "token": "pqrs", + "uses_allowed": 2, + "pending": 1, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC + } + ] +} +``` + +Example using the `valid` query parameter: + +``` +GET /_synapse/admin/v1/registration_tokens?valid=false +``` +``` +200 OK + +{ + "registration_tokens": [ + { + "token": "pqrs", + "uses_allowed": 2, + "pending": 1, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC + } + ] +} +``` + + +## Get one token + +Get details about a single token. If the request is successful, the response +body will be a registration token object. + +``` +GET /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to return details of. + +Example: + +``` +GET /_synapse/admin/v1/registration_tokens/abcd +``` +``` +200 OK + +{ + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null +} +``` + + +## Create token + +Create a new registration token. If the request is successful, the newly created +token will be returned as a registration token object in the response body. + +``` +POST /_synapse/admin/v1/registration_tokens/new +``` + +The request body must be a JSON object and can contain the following fields: +- `token`: The registration token. A string of no more than 64 characters that + consists only of characters matched by the regex `[A-Za-z0-9-_]`. + Default: randomly generated. +- `uses_allowed`: The integer number of times the token can be used to complete + a registration before it becomes invalid. + Default: `null` (unlimited uses). +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + You could use, for example, `date '+%s000' -d 'tomorrow'`. + Default: `null` (token does not expire). +- `length`: The length of the token randomly generated if `token` is not + specified. Must be between 1 and 64 inclusive. Default: `16`. + +If a field is omitted the default is used. + +Example using defaults: + +``` +POST /_synapse/admin/v1/registration_tokens/new + +{} +``` +``` +200 OK + +{ + "token": "0M-9jbkf2t_Tgiw1", + "uses_allowed": null, + "pending": 0, + "completed": 0, + "expiry_time": null +} +``` + +Example specifying some fields: + +``` +POST /_synapse/admin/v1/registration_tokens/new + +{ + "token": "defg", + "uses_allowed": 1 +} +``` +``` +200 OK + +{ + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": null +} +``` + + +## Update token + +Update the number of allowed uses or expiry time of a token. If the request is +successful, the updated token will be returned as a registration token object +in the response body. + +``` +PUT /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to update. + +The request body must be a JSON object and can contain the following fields: +- `uses_allowed`: The integer number of times the token can be used to complete + a registration before it becomes invalid. By setting `uses_allowed` to `0` + the token can be easily made invalid without deleting it. + If `null` the token will have an unlimited number of uses. +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + If `null` the token will not expire. + +If a field is omitted its value is not modified. + +Example: + +``` +PUT /_synapse/admin/v1/registration_tokens/defg + +{ + "expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC +} +``` +``` +200 OK + +{ + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": 4781243146000 +} +``` + + +## Delete token + +Delete a registration token. If the request is successful, the response body +will be an empty JSON object. + +``` +DELETE /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to delete. + +Example: + +``` +DELETE /_synapse/admin/v1/registration_tokens/wxyz +``` +``` +200 OK + +{} +``` + + +## Errors + +If a request fails a "standard error response" will be returned as defined in +the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards). + +For example, if the token specified in a path parameter does not exist a +`404 Not Found` error will be returned. + +``` +GET /_synapse/admin/v1/registration_tokens/1234 +``` +``` +404 Not Found + +{ + "errcode": "M_NOT_FOUND", + "error": "No such registration token: 1234" +} +``` diff --git a/docs/workers.md b/docs/workers.md index 2e63f03452..3121241894 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -236,6 +236,7 @@ expressions: # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ ^/_matrix/client/(r0|unstable)/register$ + ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$ # Event sending requests ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e0e24fddac..829061c870 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -79,6 +79,7 @@ class LoginType: TERMS = "m.login.terms" SSO = "m.login.sso" DUMMY = "m.login.dummy" + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" # This is used in the `type` parameter for /register when called by diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 845e6a8220..fd2626dbe1 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -95,7 +95,10 @@ from synapse.rest.client.profile import ( ProfileRestServlet, ) from synapse.rest.client.push_rule import PushRuleRestServlet -from synapse.rest.client.register import RegisterRestServlet +from synapse.rest.client.register import ( + RegisterRestServlet, + RegistrationTokenValidityRestServlet, +) from synapse.rest.client.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.client.voip import VoipRestServlet @@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer): resource = JsonResource(self, canonical_json=False) RegisterRestServlet(self).register(resource) + RegistrationTokenValidityRestServlet(self).register(resource) login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 7a8d5851c4..f856327bd8 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -79,6 +79,11 @@ class RatelimitConfig(Config): self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) + self.rc_registration_token_validity = RateLimitConfig( + config.get("rc_registration_token_validity", {}), + defaults={"per_second": 0.1, "burst_count": 5}, + ) + rc_login_config = config.get("rc_login", {}) self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {})) @@ -143,6 +148,8 @@ class RatelimitConfig(Config): # is using # - one for registration that ratelimits registration requests based on the # client's IP address. + # - one for checking the validity of registration tokens that ratelimits + # requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -171,6 +178,10 @@ class RatelimitConfig(Config): # per_second: 0.17 # burst_count: 3 # + #rc_registration_token_validity: + # per_second: 0.1 + # burst_count: 5 + # #rc_login: # address: # per_second: 0.17 diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 0ad919b139..7cffdacfa5 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -33,6 +33,9 @@ class RegistrationConfig(Config): self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) + self.registration_requires_token = config.get( + "registration_requires_token", False + ) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -140,6 +143,9 @@ class RegistrationConfig(Config): "mechanism by removing the `access_token_lifetime` option." ) + # The fallback template used for authenticating using a registration token + self.registration_token_template = self.read_template("registration_token.html") + # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") @@ -199,6 +205,15 @@ class RegistrationConfig(Config): # #enable_3pid_lookup: true + # Require users to submit a token during registration. + # Tokens can be managed using the admin API: + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html + # Note that `enable_registration` must be set to `true`. + # Disabling this option will not delete any tokens previously generated. + # Defaults to false. Uncomment the following to require tokens: + # + #registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py index 4c3b669fae..13b0c61d2e 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py @@ -34,3 +34,8 @@ class UIAuthSessionDataConstants: # used by validate_user_via_ui_auth to store the mxid of the user we are validating # for. REQUEST_USER_ID = "request_user_id" + + # used during registration to store the registration token used (if required) so that: + # - we can prevent a token being used twice by one session + # - we can 'use up' the token after registration has successfully completed + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 270541cc76..d3828dec6b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): return await self._check_threepid("msisdn", authdict) +class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.REGISTRATION_TOKEN + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.hs = hs + self._enabled = bool(hs.config.registration_requires_token) + self.store = hs.get_datastore() + + def is_enabled(self) -> bool: + return self._enabled + + async def check_auth(self, authdict: dict, clientip: str) -> Any: + if "token" not in authdict: + raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM) + if not isinstance(authdict["token"], str): + raise LoginError( + 400, "Registration token must be a string", Codes.INVALID_PARAM + ) + if "session" not in authdict: + raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM) + + # Get these here to avoid cyclic dependencies + from synapse.handlers.ui_auth import UIAuthSessionDataConstants + + auth_handler = self.hs.get_auth_handler() + + session = authdict["session"] + token = authdict["token"] + + # If the LoginType.REGISTRATION_TOKEN stage has already been completed, + # return early to avoid incrementing `pending` again. + stored_token = await auth_handler.get_session_data( + session, UIAuthSessionDataConstants.REGISTRATION_TOKEN + ) + if stored_token: + if token != stored_token: + raise LoginError( + 400, "Registration token has changed", Codes.INVALID_PARAM + ) + else: + return token + + if await self.store.registration_token_is_valid(token): + # Increment pending counter, so that if token has limited uses it + # can't be used up by someone else in the meantime. + await self.store.set_registration_token_pending(token) + # Store the token in the UIA session, so that once registration + # is complete `completed` can be incremented. + await auth_handler.set_session_data( + session, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + token, + ) + # The token will be stored as the result of the authentication stage + # in ui_auth_sessions_credentials. This allows the pending counter + # for tokens to be decremented when expired sessions are deleted. + return token + else: + raise LoginError( + 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED + ) + + INTERACTIVE_AUTH_CHECKERS = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, EmailIdentityAuthChecker, MsisdnAuthChecker, + RegistrationTokenAuthChecker, ] """A list of UserInteractiveAuthChecker classes""" diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html new file mode 100644 index 0000000000..4577ce1702 --- /dev/null +++ b/synapse/res/templates/registration_token.html @@ -0,0 +1,23 @@ + + +Authentication + + + + +
+
+ {% if error is defined %} +

Error: {{ error }}

+ {% endif %} +

+ Please enter a registration token. +

+ + + +
+
+ + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 7f3051aef1..6e1c8736e1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import ( ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo +from synapse.rest.admin.registration_tokens import ( + ListRegistrationTokensRestServlet, + NewRegistrationTokenRestServlet, + RegistrationTokenRestServlet, +) from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, ForwardExtremitiesRestServlet, @@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomEventContextServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server) + ListRegistrationTokensRestServlet(hs).register(http_server) + NewRegistrationTokenRestServlet(hs).register(http_server) + RegistrationTokenRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py new file mode 100644 index 0000000000..5a1c929d85 --- /dev/null +++ b/synapse/rest/admin/registration_tokens.py @@ -0,0 +1,321 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import string +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_boolean, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ListRegistrationTokensRestServlet(RestServlet): + """List registration tokens. + + To list all tokens: + + GET /_synapse/admin/v1/registration_tokens + + 200 OK + + { + "registration_tokens": [ + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 + } + ] + } + + The optional query parameter `valid` can be used to filter the response. + If it is `true`, only valid tokens are returned. If it is `false`, only + tokens that have expired or have had all uses exhausted are returned. + If it is omitted, all tokens are returned regardless of validity. + """ + + PATTERNS = admin_patterns("/registration_tokens$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + valid = parse_boolean(request, "valid") + token_list = await self.store.get_registration_tokens(valid) + return 200, {"registration_tokens": token_list} + + +class NewRegistrationTokenRestServlet(RestServlet): + """Create a new registration token. + + For example, to create a token specifying some fields: + + POST /_synapse/admin/v1/registration_tokens/new + + { + "token": "defg", + "uses_allowed": 1 + } + + 200 OK + + { + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": null + } + + Defaults are used for any fields not specified. + """ + + PATTERNS = admin_patterns("/registration_tokens/new$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + # A string of all the characters allowed to be in a registration_token + self.allowed_chars = string.ascii_letters + string.digits + "-_" + self.allowed_chars_set = set(self.allowed_chars) + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + + if "token" in body: + token = body["token"] + if not isinstance(token, str): + raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + if not (0 < len(token) <= 64): + raise SynapseError( + 400, + "token must not be empty and must not be longer than 64 characters", + Codes.INVALID_PARAM, + ) + if not set(token).issubset(self.allowed_chars_set): + raise SynapseError( + 400, + "token must consist only of characters matched by the regex [A-Za-z0-9-_]", + Codes.INVALID_PARAM, + ) + + else: + # Get length of token to generate (default is 16) + length = body.get("length", 16) + if not isinstance(length, int): + raise SynapseError( + 400, "length must be an integer", Codes.INVALID_PARAM + ) + if not (0 < length <= 64): + raise SynapseError( + 400, + "length must be greater than zero and not greater than 64", + Codes.INVALID_PARAM, + ) + + # Generate token + token = await self.store.generate_registration_token( + length, self.allowed_chars + ) + + uses_allowed = body.get("uses_allowed", None) + if not ( + uses_allowed is None + or (isinstance(uses_allowed, int) and uses_allowed >= 0) + ): + raise SynapseError( + 400, + "uses_allowed must be a non-negative integer or null", + Codes.INVALID_PARAM, + ) + + expiry_time = body.get("expiry_time", None) + if not isinstance(expiry_time, (int, type(None))): + raise SynapseError( + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + ) + if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + raise SynapseError( + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + ) + + created = await self.store.create_registration_token( + token, uses_allowed, expiry_time + ) + if not created: + raise SynapseError( + 400, f"Token already exists: {token}", Codes.INVALID_PARAM + ) + + resp = { + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + } + return 200, resp + + +class RegistrationTokenRestServlet(RestServlet): + """Retrieve, update, or delete the given token. + + For example, + + to retrieve a token: + + GET /_synapse/admin/v1/registration_tokens/abcd + + 200 OK + + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + } + + + to update a token: + + PUT /_synapse/admin/v1/registration_tokens/defg + + { + "uses_allowed": 5, + "expiry_time": 4781243146000 + } + + 200 OK + + { + "token": "defg", + "uses_allowed": 5, + "pending": 0, + "completed": 0, + "expiry_time": 4781243146000 + } + + + to delete a token: + + DELETE /_synapse/admin/v1/registration_tokens/wxyz + + 200 OK + + {} + """ + + PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.clock = hs.get_clock() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: + """Retrieve a registration token.""" + await assert_requester_is_admin(self.auth, request) + token_info = await self.store.get_one_registration_token(token) + + # If no result return a 404 + if token_info is None: + raise NotFoundError(f"No such registration token: {token}") + + return 200, token_info + + async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: + """Update a registration token.""" + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + new_attributes = {} + + # Only add uses_allowed to new_attributes if it is present and valid + if "uses_allowed" in body: + uses_allowed = body["uses_allowed"] + if not ( + uses_allowed is None + or (isinstance(uses_allowed, int) and uses_allowed >= 0) + ): + raise SynapseError( + 400, + "uses_allowed must be a non-negative integer or null", + Codes.INVALID_PARAM, + ) + new_attributes["uses_allowed"] = uses_allowed + + if "expiry_time" in body: + expiry_time = body["expiry_time"] + if not isinstance(expiry_time, (int, type(None))): + raise SynapseError( + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + ) + if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + raise SynapseError( + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + ) + new_attributes["expiry_time"] = expiry_time + + if len(new_attributes) == 0: + # Nothing to update, get token info to return + token_info = await self.store.get_one_registration_token(token) + else: + token_info = await self.store.update_registration_token( + token, new_attributes + ) + + # If no result return a 404 + if token_info is None: + raise NotFoundError(f"No such registration token: {token}") + + return 200, token_info + + async def on_DELETE( + self, request: SynapseRequest, token: str + ) -> Tuple[int, JsonDict]: + """Delete a registration token.""" + await assert_requester_is_admin(self.auth, request) + + if await self.store.delete_registration_token(token): + return 200, {} + + raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 73284e48ec..91800c0278 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.recaptcha_template self.terms_template = hs.config.terms_template + self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template async def on_GET(self, request, stagetype): @@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet): # re-authenticate with their SSO provider. html = await self.auth_handler.start_sso_ui_auth(request, session) + elif stagetype == LoginType.REGISTRATION_TOKEN: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + ) + else: raise SynapseError(404, "Unknown auth stage type") @@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet): # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + elif stagetype == LoginType.REGISTRATION_TOKEN: + token = parse_string(request, "token", required=True) + authdict = {"session": session, "token": token} + + try: + await self.auth_handler.add_oob_auth( + LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP() + ) + except LoginError as e: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + error=e.msg, + ) + else: + html = self.success_template.render() + else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 58b8e8f261..2781a0ea96 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -28,6 +28,7 @@ from synapse.api.errors import ( ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError from synapse.config.captcha import CaptchaConfig from synapse.config.consent import ConsentConfig @@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet): return 200, {"available": True} +class RegistrationTokenValidityRestServlet(RestServlet): + """Check the validity of a registration token. + + Example: + + GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd + + 200 OK + + { + "valid": true + } + """ + + PATTERNS = client_patterns( + f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity", + releases=(), + unstable=True, + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.store = hs.get_datastore() + self.ratelimiter = Ratelimiter( + store=self.store, + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, + burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + ) + + async def on_GET(self, request): + await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) + + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + + token = parse_string(request, "token", required=True) + valid = await self.store.registration_token_is_valid(token) + + return 200, {"valid": valid} + + class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") @@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet): ) if registered: + # Check if a token was used to authenticate registration + registration_token = await self.auth_handler.get_session_data( + session_id, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + ) + if registration_token: + # Increment the `completed` counter for the token + await self.store.use_registration_token(registration_token) + # Indicate that the token has been successfully used so that + # pending is not decremented again when expiring old UIA sessions. + await self.store.mark_ui_auth_stage_complete( + session_id, + LoginType.REGISTRATION_TOKEN, + True, + ) + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, @@ -868,6 +934,11 @@ def _calculate_registration_flows( for flow in flows: flow.insert(0, LoginType.RECAPTCHA) + # Prepend registration token to all flows if we're requiring a token + if config.registration_requires_token: + for flow in flows: + flow.insert(0, LoginType.REGISTRATION_TOKEN) + return flows @@ -876,4 +947,5 @@ def register_servlets(hs, http_server): MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) RegistrationSubmitTokenServlet(hs).register(http_server) + RegistrationTokenValidityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 469dd53e0c..a6517962f6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + async def registration_token_is_valid(self, token: str) -> bool: + """Checks if a token can be used to authenticate a registration. + + Args: + token: The registration token to be checked + Returns: + True if the token is valid, False otherwise. + """ + res = await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + ) + + # Check if the token exists + if res is None: + return False + + # Check if the token has expired + now = self._clock.time_msec() + if res["expiry_time"] and res["expiry_time"] < now: + return False + + # Check if the token has been used up + if ( + res["uses_allowed"] + and res["pending"] + res["completed"] >= res["uses_allowed"] + ): + return False + + # Otherwise, the token is valid + return True + + async def set_registration_token_pending(self, token: str) -> None: + """Increment the pending registrations counter for a token. + + Args: + token: The registration token pending use + """ + + def _set_registration_token_pending_txn(txn): + pending = self.db_pool.simple_select_one_onecol_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": pending + 1}, + ) + + return await self.db_pool.runInteraction( + "set_registration_token_pending", _set_registration_token_pending_txn + ) + + async def use_registration_token(self, token: str) -> None: + """Complete a use of the given registration token. + + The `pending` counter will be decremented, and the `completed` + counter will be incremented. + + Args: + token: The registration token to be 'used' + """ + + def _use_registration_token_txn(txn): + # Normally, res is Optional[Dict[str, Any]]. + # Override type because the return type is only optional if + # allow_none is True, and we don't want mypy throwing errors + # about None not being indexable. + res: Dict[str, Any] = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) # type: ignore + + # Decrement pending and increment completed + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={ + "completed": res["completed"] + 1, + "pending": res["pending"] - 1, + }, + ) + + return await self.db_pool.runInteraction( + "use_registration_token", _use_registration_token_txn + ) + + async def get_registration_tokens( + self, valid: Optional[bool] = None + ) -> List[Dict[str, Any]]: + """List all registration tokens. Used by the admin API. + + Args: + valid: If True, only valid tokens are returned. + If False, only invalid tokens are returned. + Default is None: return all tokens regardless of validity. + + Returns: + A list of dicts, each containing details of a token. + """ + + def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + if valid is None: + # Return all tokens regardless of validity + txn.execute("SELECT * FROM registration_tokens") + + elif valid: + # Select valid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "(uses_allowed > pending + completed OR uses_allowed IS NULL) " + "AND (expiry_time > ? OR expiry_time IS NULL)" + ) + txn.execute(sql, [now]) + + else: + # Select invalid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "uses_allowed <= pending + completed OR expiry_time <= ?" + ) + txn.execute(sql, [now]) + + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "select_registration_tokens", + select_registration_tokens_txn, + self._clock.time_msec(), + valid, + ) + + async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]: + """Get info about the given registration token. Used by the admin API. + + Args: + token: The token to retrieve information about. + + Returns: + A dict, or None if token doesn't exist. + """ + return await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + desc="get_one_registration_token", + ) + + async def generate_registration_token( + self, length: int, chars: str + ) -> Optional[str]: + """Generate a random registration token. Used by the admin API. + + Args: + length: The length of the token to generate. + chars: A string of the characters allowed in the generated token. + + Returns: + The generated token. + + Raises: + SynapseError if a unique registration token could still not be + generated after a few tries. + """ + # Make a few attempts at generating a unique token of the required + # length before failing. + for _i in range(3): + # Generate token + token = "".join(random.choices(chars, k=length)) + + # Check if the token already exists + existing_token = await self.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="token", + allow_none=True, + desc="check_if_registration_token_exists", + ) + + if existing_token is None: + # The generated token doesn't exist yet, return it + return token + + raise SynapseError( + 500, + "Unable to generate a unique registration token. Try again with a greater length", + Codes.UNKNOWN, + ) + + async def create_registration_token( + self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] + ) -> bool: + """Create a new registration token. Used by the admin API. + + Args: + token: The token to create. + uses_allowed: The number of times the token can be used to complete + a registration before it becomes invalid. A value of None indicates + unlimited uses. + expiry_time: The latest time the token is valid. Given as the + number of milliseconds since 1970-01-01 00:00:00 UTC. A value of + None indicates that the token does not expire. + + Returns: + Whether the row was inserted or not. + """ + + def _create_registration_token_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["token"], + allow_none=True, + ) + + if row is not None: + # Token already exists + return False + + self.db_pool.simple_insert_txn( + txn, + "registration_tokens", + values={ + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + }, + ) + + return True + + return await self.db_pool.runInteraction( + "create_registration_token", _create_registration_token_txn + ) + + async def update_registration_token( + self, token: str, updatevalues: Dict[str, Optional[int]] + ) -> Optional[Dict[str, Any]]: + """Update a registration token. Used by the admin API. + + Args: + token: The token to update. + updatevalues: A dict with the fields to update. E.g.: + `{"uses_allowed": 3}` to update just uses_allowed, or + `{"uses_allowed": 3, "expiry_time": None}` to update both. + This is passed straight to simple_update_one. + + Returns: + A dict with all info about the token, or None if token doesn't exist. + """ + + def _update_registration_token_txn(txn): + try: + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues=updatevalues, + ) + except StoreError: + # Update failed because token does not exist + return None + + # Get all info about the token so it can be sent in the response + return self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=[ + "token", + "uses_allowed", + "pending", + "completed", + "expiry_time", + ], + allow_none=True, + ) + + return await self.db_pool.runInteraction( + "update_registration_token", _update_registration_token_txn + ) + + async def delete_registration_token(self, token: str) -> bool: + """Delete a registration token. Used by the admin API. + + Args: + token: The token to delete. + + Returns: + Whether the token was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + desc="delete_registration_token", + ) + except StoreError: + # Deletion failed because token does not exist + return False + + return True + @cached() async def mark_access_token_as_used(self, token_id: int) -> None: """ diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 38bfdf5dad..4d6bbc94c7 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import attr +from synapse.api.constants import LoginType from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction @@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore): keyvalues={}, ) + # If a registration token was used, decrement the pending counter + # before deleting the session. + rows = self.db_pool.simple_select_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, + retcols=["result"], + ) + + # Get the tokens used and how much pending needs to be decremented by. + token_counts: Dict[str, int] = {} + for r in rows: + # If registration was successfully completed, the result of the + # registration token stage for that session will be True. + # If a token was used to authenticate, but registration was + # never completed, the result will be the token used. + token = db_to_json(r["result"]) + if isinstance(token, str): + token_counts[token] = token_counts.get(token, 0) + 1 + + # Update the `pending` counters. + if len(token_counts) > 0: + token_rows = self.db_pool.simple_select_many_txn( + txn, + table="registration_tokens", + column="token", + iterable=list(token_counts.keys()), + keyvalues={}, + retcols=["token", "pending"], + ) + for token_row in token_rows: + token = token_row["token"] + new_pending = token_row["pending"] - token_counts[token] + self.db_pool.simple_update_one_txn( + txn, + table="registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": new_pending}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql new file mode 100644 index 0000000000..ee6cf958f4 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql @@ -0,0 +1,23 @@ +/* Copyright 2021 Callum Brown + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS registration_tokens( + token TEXT NOT NULL, -- The token that can be used for authentication. + uses_allowed INT, -- The total number of times this token can be used. NULL if no limit. + pending INT NOT NULL, -- The number of in progress registrations using this token. + completed INT NOT NULL, -- The number of times this token has been used to complete a registration. + expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire. + UNIQUE (token) +); diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py new file mode 100644 index 0000000000..4927321e5a --- /dev/null +++ b/tests/rest/admin/test_registration_tokens.py @@ -0,0 +1,710 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login + +from tests import unittest + + +class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/registration_tokens" + + def _new_token(self, **kwargs): + """Helper function to create a token.""" + token = kwargs.get( + "token", + "".join(random.choices(string.ascii_letters, k=8)), + ) + self.get_success( + self.store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": kwargs.get("uses_allowed", None), + "pending": kwargs.get("pending", 0), + "completed": kwargs.get("completed", 0), + "expiry_time": kwargs.get("expiry_time", None), + }, + ) + ) + return token + + # CREATION + + def test_create_no_auth(self): + """Try to create a token without authentication.""" + channel = self.make_request("POST", self.url + "/new", {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_create_requester_not_admin(self): + """Try to create a token while not an admin.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_create_using_defaults(self): + """Create a token using all the defaults.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_specifying_fields(self): + """Create a token specifying the value of all fields.""" + data = { + "token": "abcd", + "uses_allowed": 1, + "expiry_time": self.clock.time_msec() + 1000000, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], "abcd") + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_with_null_value(self): + """Create a token specifying unlimited uses and no expiry.""" + data = { + "uses_allowed": None, + "expiry_time": None, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_token_too_long(self): + """Check token longer than 64 chars is invalid.""" + data = {"token": "a" * 65} + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_invalid_chars(self): + """Check you can't create token with invalid characters.""" + data = { + "token": "abc/def", + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_already_exists(self): + """Check you can't create token that already exists.""" + data = { + "token": "abcd", + } + + channel1 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + + channel2 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_unable_to_generate_token(self): + """Check right error is raised when server can't generate unique token.""" + # Create all possible single character tokens + tokens = [] + for c in string.ascii_letters + string.digits + "-_": + tokens.append( + { + "token": c, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + } + ) + self.get_success( + self.store.db_pool.simple_insert_many( + "registration_tokens", + tokens, + "create_all_registration_tokens", + ) + ) + + # Check creating a single character token fails with a 500 status code + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + + def test_create_uses_allowed(self): + """Check you can only create a token with good values for uses_allowed.""" + # Should work with 0 (token is invalid from the start) + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + + # Should fail with negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_expiry_time(self): + """Check you can't create a token with an invalid expiry_time.""" + # Should fail with a time in the past + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() - 10000}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() + 1000000.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_length(self): + """Check you can only generate a token with a valid length.""" + # Should work with 64 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 64}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 64) + + # Should fail with 0 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"length": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a float + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 8.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with 65 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 65}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # UPDATING + + def test_update_no_auth(self): + """Try to update a token without authentication.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_update_requester_not_admin(self): + """Try to update a token while not an admin.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_update_non_existent(self): + """Try to update a token that doesn't exist.""" + channel = self.make_request( + "PUT", + self.url + "/1234", + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_update_uses_allowed(self): + """Test updating just uses_allowed.""" + # Create new token using default values + token = self._new_token() + + # Should succeed with 1 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with 0 (makes token invalid) + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should fail with a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_expiry_time(self): + """Test updating just expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + # Should succeed with a time in the future + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should fail with a time in the past + past_time = self.clock.time_msec() - 10000 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": past_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time + 0.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_both(self): + """Test updating both uses_allowed and expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + data = { + "uses_allowed": 1, + "expiry_time": new_expiry_time, + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + + def test_update_invalid_type(self): + """Test using invalid types doesn't work.""" + # Create new token using default values + token = self._new_token() + + data = { + "uses_allowed": False, + "expiry_time": "1626430124000", + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # DELETING + + def test_delete_no_auth(self): + """Try to delete a token without authentication.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_delete_requester_not_admin(self): + """Try to delete a token while not an admin.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_non_existent(self): + """Try to delete a token that doesn't exist.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_delete(self): + """Test deleting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "DELETE", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # GETTING ONE + + def test_get_no_auth(self): + """Try to get a token without authentication.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_get_requester_not_admin(self): + """Try to get a token while not an admin.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_get_non_existent(self): + """Try to get a token that doesn't exist.""" + channel = self.make_request( + "GET", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_get(self): + """Test getting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], token) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + # LISTING + + def test_list_no_auth(self): + """Try to list tokens without authentication.""" + channel = self.make_request("GET", self.url, {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_list_requester_not_admin(self): + """Try to list tokens while not an admin.""" + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_list_all(self): + """Test listing all tokens.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 1) + token_info = channel.json_body["registration_tokens"][0] + self.assertEqual(token_info["token"], token) + self.assertIsNone(token_info["uses_allowed"]) + self.assertIsNone(token_info["expiry_time"]) + self.assertEqual(token_info["pending"], 0) + self.assertEqual(token_info["completed"], 0) + + def test_list_invalid_query_parameter(self): + """Test with `valid` query parameter not `true` or `false`.""" + channel = self.make_request( + "GET", + self.url + "?valid=x", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + def _test_list_query_parameter(self, valid: str): + """Helper used to test both valid=true and valid=false.""" + # Create 2 valid and 2 invalid tokens. + now = self.hs.get_clock().time_msec() + # Create always valid token + valid1 = self._new_token() + # Create token that hasn't been used up + valid2 = self._new_token(uses_allowed=1) + # Create token that has expired + invalid1 = self._new_token(expiry_time=now - 10000) + # Create token that has been used up but hasn't expired + invalid2 = self._new_token( + uses_allowed=2, + pending=1, + completed=1, + expiry_time=now + 1000000, + ) + + if valid == "true": + tokens = [valid1, valid2] + else: + tokens = [invalid1, invalid2] + + channel = self.make_request( + "GET", + self.url + "?valid=" + valid, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 2) + token_info_1 = channel.json_body["registration_tokens"][0] + token_info_2 = channel.json_body["registration_tokens"][1] + self.assertIn(token_info_1["token"], tokens) + self.assertIn(token_info_2["token"], tokens) + + def test_list_valid(self): + """Test listing just valid tokens.""" + self._test_list_query_parameter(valid="true") + + def test_list_invalid(self): + """Test listing just invalid tokens.""" + self._test_list_query_parameter(valid="false") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index fecda037a5..9f3ab2c985 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.storage._base import db_to_json from tests import unittest from tests.unittest import override_config @@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"registration_requires_token": True}) + def test_POST_registration_requires_token(self): + username = "kermit" + device_id = "frogfone" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params = { + "username": username, + "password": "monkey", + "device_id": device_id, + } + + # Request without auth to get flows and session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + # Synapse adds a dummy stage to differentiate flows where otherwise one + # flow would be a subset of another flow. + self.assertCountEqual( + [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]], + (f["stages"] for f in flows), + ) + session = channel.json_body["session"] + + # Do the registration token stage and check it has completed + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + self.assertEquals(channel.result["code"], b"401", channel.result) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + # Do the m.login.dummy stage and check registration was successful + params["auth"] = { + "type": LoginType.DUMMY, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + det_data = { + "user_id": f"@{username}:{self.hs.hostname}", + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + # Check the `completed` counter has been incremented and pending is 0 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["completed"], 1) + self.assertEquals(res["pending"], 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_invalid(self): + params = { + "username": "kermit", + "password": "monkey", + } + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Test with token param missing (invalid) + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with non-string (invalid) + params["auth"]["token"] = 1234 + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with unknown token (invalid) + params["auth"]["token"] = "1234" + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_limit_uses(self): + token = "abcd" + store = self.hs.get_datastore() + # Create token that can be used once + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + # Do 2 requests without auth to get two session IDs + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with session1 and check `pending` is 1 + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + # Repeat request to make sure pending isn't increased again + self.make_request(b"POST", self.url, json.dumps(params1)) + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 1) + + # Check auth fails when using token with session2 + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + # Check pending=0 and completed=1 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["pending"], 0) + self.assertEquals(res["completed"], 1) + + # Check auth still fails when using token with session2 + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_expiry(self): + token = "abcd" + now = self.hs.get_clock().time_msec() + store = self.hs.get_datastore() + # Create token that expired yesterday + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": now - 24 * 60 * 60 * 1000, + }, + ) + ) + params = {"username": "kermit", "password": "monkey"} + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Check authentication fails with expired token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Update token so it expires tomorrow + self.get_success( + store.db_pool.simple_update_one( + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000}, + ) + ) + + # Check authentication succeeds + channel = self.make_request(b"POST", self.url, json.dumps(params)) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry(self): + """Test `pending` is decremented when an uncompleted session expires.""" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do 2 requests without auth to get two session IDs + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with both sessions + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + self.make_request(b"POST", self.url, json.dumps(params2)) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + + # Check `result` of registration token stage for session1 is `True` + result1 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session1, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertTrue(db_to_json(result1)) + + # Check `result` for session2 is the token used + result2 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session2, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertEquals(db_to_json(result2), token) + + # Delete both sessions (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + + # Check pending is now 0 + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry_deleted_token(self): + """Test session expiry doesn't break when the token is deleted. + + 1. Start but don't complete UIA with a registration token + 2. Delete the token from the database + 3. Expire the session + """ + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do request without auth to get a session ID + params = {"username": "kermit", "password": "monkey"} + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Use token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + self.make_request(b"POST", self.url, json.dumps(params)) + + # Delete token + self.get_success( + store.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + ) + ) + + # Delete session (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + def test_advertised_flows(self): channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) @@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) + + +class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): + servlets = [register.register_servlets] + url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity" + + def default_config(self): + config = super().default_config() + config["registration_requires_token"] = True + return config + + def test_GET_token_valid(self): + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], True) + + def test_GET_token_invalid(self): + token = "1234" + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], False) + + @override_config( + {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} + ) + def test_GET_ratelimiting(self): + token = "1234" + + for i in range(0, 6): + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) -- cgit 1.5.1 From 7367473f965fed1160cb8633de341c5833e5b662 Mon Sep 17 00:00:00 2001 From: Sean Date: Wed, 25 Aug 2021 10:51:08 +0100 Subject: Fix error when selecting between thumbnails with the same quality (#10684) Fixes #10318 --- changelog.d/10684.bugfix | 1 + synapse/rest/media/v1/thumbnail_resource.py | 26 ++++++++++++------- tests/rest/media/v1/test_media_storage.py | 39 ++++++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10684.bugfix (limited to 'tests') diff --git a/changelog.d/10684.bugfix b/changelog.d/10684.bugfix new file mode 100644 index 0000000000..311b17601a --- /dev/null +++ b/changelog.d/10684.bugfix @@ -0,0 +1 @@ +Fix long-standing issue which caused an error when a thumbnail is requested and there are multiple thumbnails with the same quality rating. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index a029d426f0..12bd745cb2 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -15,7 +15,7 @@ import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from twisted.web.server import Request @@ -414,9 +414,9 @@ class ThumbnailResource(DirectServeJsonResource): if desired_method == "crop": # Thumbnails that match equal or larger sizes of desired width/height. - crop_info_list = [] + crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] # Other thumbnails. - crop_info_list2 = [] + crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. if info["thumbnail_method"] != "crop": @@ -451,15 +451,19 @@ class ThumbnailResource(DirectServeJsonResource): info, ) ) + # Pick the most appropriate thumbnail. Some values of `desired_width` and + # `desired_height` may result in a tie, in which case we avoid comparing on + # the thumbnail info dictionary and pick the thumbnail that appears earlier + # in the list of candidates. if crop_info_list: - thumbnail_info = min(crop_info_list)[-1] + thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1] elif crop_info_list2: - thumbnail_info = min(crop_info_list2)[-1] + thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1] elif desired_method == "scale": # Thumbnails that match equal or larger sizes of desired width/height. - info_list = [] + info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = [] # Other thumbnails. - info_list2 = [] + info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. @@ -477,10 +481,14 @@ class ThumbnailResource(DirectServeJsonResource): info_list2.append( (size_quality, type_quality, length_quality, info) ) + # Pick the most appropriate thumbnail. Some values of `desired_width` and + # `desired_height` may result in a tie, in which case we avoid comparing on + # the thumbnail info dictionary and pick the thumbnail that appears earlier + # in the list of candidates. if info_list: - thumbnail_info = min(info_list)[-1] + thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1] elif info_list2: - thumbnail_info = min(info_list2)[-1] + thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1] if thumbnail_info: return FileInfo( diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6085444b9d..2f7eebfe69 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -21,7 +21,7 @@ from unittest.mock import Mock from urllib import parse import attr -from parameterized import parameterized_class +from parameterized import parameterized, parameterized_class from PIL import Image as Image from twisted.internet import defer @@ -473,6 +473,43 @@ class MediaRepoTests(unittest.HomeserverTestCase): }, ) + @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) + def test_same_quality(self, method, desired_size): + """Test that choosing between thumbnails with the same quality rating succeeds. + + We are not particular about which thumbnail is chosen.""" + self.assertIsNotNone( + self.thumbnail_resource._select_thumbnail( + desired_width=desired_size, + desired_height=desired_size, + desired_method=method, + desired_type=self.test_image.content_type, + # Provide two identical thumbnails which are guaranteed to have the same + # quality rating. + thumbnail_infos=[ + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail1{self.test_image.extension}", + }, + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail2{self.test_image.extension}", + }, + ], + file_id=f"image{self.test_image.extension}", + url_cache=None, + server_name=None, + ) + ) + def test_x_robots_tag_header(self): """ Tests that the `X-Robots-Tag` header is present, which informs web crawlers -- cgit 1.5.1 From ad17fbd20eb2dd9fb10a3d02ab1b69e9a0d5b50c Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Thu, 26 Aug 2021 13:53:57 +0100 Subject: Remove pushers when deleting 3pid from account (#10581) When a user deletes an email from their account it will now also remove all pushers for that email and that user (even if these pushers were created by a different client) --- CHANGES.md | 2 + changelog.d/10581.bugfix | 1 + docs/upgrade.md | 5 ++ synapse/handlers/auth.py | 5 +- synapse/storage/databases/main/pusher.py | 72 ++++++++++++++++++++++ .../delta/63/02delete_unlinked_email_pushers.sql | 20 ++++++ tests/push/test_email.py | 39 ++++++++++++ 7 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10581.bugfix create mode 100644 synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql (limited to 'tests') diff --git a/CHANGES.md b/CHANGES.md index f8da8771aa..24f3d53a6d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,5 @@ +Users will stop receiving message updates via email for addresses that were previously linked to their account + Synapse 1.41.0 (2021-08-24) =========================== diff --git a/changelog.d/10581.bugfix b/changelog.d/10581.bugfix new file mode 100644 index 0000000000..15c7da4497 --- /dev/null +++ b/changelog.d/10581.bugfix @@ -0,0 +1 @@ +Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 6d4b8cb48e..dcf0a7db5b 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -107,6 +107,11 @@ This may affect you if you make use of custom HTML templates for the The template is now provided an `error` variable if the authentication process failed. See the default templates linked above for an example. +# Upgrading to v1.42.0 + +## Removal of out-of-date email pushers +Users will stop receiving message updates via email for addresses that were +once, but not still, linked to their account. # Upgrading to v1.41.0 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 98d3d2d97f..34725324a6 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1464,6 +1464,10 @@ class AuthHandler(BaseHandler): ) await self.store.user_delete_threepid(user_id, medium, address) + if medium == "email": + await self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id="m.email", pushkey=address, user_id=user_id + ) return result async def hash(self, password: str) -> str: @@ -1732,7 +1736,6 @@ class AuthHandler(BaseHandler): @attr.s(slots=True) class MacaroonGenerator: - hs = attr.ib() def generate_guest_access_token(self, user_id: str) -> str: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index b48fe086d4..e47caa2125 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore): self._remove_stale_pushers, ) + self.db_pool.updates.register_background_update_handler( + "remove_deleted_email_pushers", + self._remove_deleted_email_pushers, + ) + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table @@ -388,6 +393,73 @@ class PusherWorkerStore(SQLBaseStore): return number_deleted + async def _remove_deleted_email_pushers( + self, progress: dict, batch_size: int + ) -> int: + """A background update that deletes all pushers for deleted email addresses. + + In previous versions of synapse, when users deleted their email address, it didn't + also delete all the pushers for that email address. This background update removes + those to prevent unwanted emails. This should only need to be run once (when users + upgrade to v1.42.0 + + Args: + progress: dict used to store progress of this background update + batch_size: the maximum number of rows to retrieve in a single select query + + Returns: + The number of deleted rows + """ + + last_pusher = progress.get("last_pusher", 0) + + def _delete_pushers(txn) -> int: + + sql = """ + SELECT p.id, p.user_name, p.app_id, p.pushkey + FROM pushers AS p + LEFT JOIN user_threepids AS t + ON t.user_id = p.user_name + AND t.medium = 'email' + AND t.address = p.pushkey + WHERE t.user_id is NULL + AND p.app_id = 'm.email' + AND p.id > ? + ORDER BY p.id ASC + LIMIT ? + """ + + txn.execute(sql, (last_pusher, batch_size)) + + last = None + num_deleted = 0 + for row in txn: + last = row[0] + num_deleted += 1 + self.db_pool.simple_delete_txn( + txn, + "pushers", + {"user_name": row[1], "app_id": row[2], "pushkey": row[3]}, + ) + + if last is not None: + self.db_pool.updates._background_update_progress_txn( + txn, "remove_deleted_email_pushers", {"last_pusher": last} + ) + + return num_deleted + + number_deleted = await self.db_pool.runInteraction( + "_remove_deleted_email_pushers", _delete_pushers + ) + + if number_deleted < batch_size: + await self.db_pool.updates._end_background_update( + "remove_deleted_email_pushers" + ) + + return number_deleted + class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self) -> int: diff --git a/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql new file mode 100644 index 0000000000..611c4b95cf --- /dev/null +++ b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We may not have deleted all pushers for emails that are no longer linked +-- to an account, so we set up a background job to delete them. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6302, 'remove_deleted_email_pushers', '{}'); diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e0a3342088..eea07485a0 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -125,6 +125,8 @@ class EmailPusherTests(HomeserverTestCase): ) ) + self.auth_handler = hs.get_auth_handler() + def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated their email. @@ -305,6 +307,43 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() + def test_no_email_sent_after_removed(self): + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, + src=self.user_id, + tok=self.access_token, + targ=self.others[0].id, + ) + self.helper.join( + room=room, + user=self.others[0].id, + tok=self.others[0].token, + ) + + # The other user sends a single message. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + + # We should get emailed about that message + self._check_for_mail() + + # disassociate the user's email address + self.get_success( + self.auth_handler.delete_threepid( + user_id=self.user_id, + medium="email", + address="a@example.com", + ) + ) + + # check that the pusher for that email address has been deleted + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 0) + def _check_for_mail(self): """Check that the user receives an email notification""" -- cgit 1.5.1 From 40f619eaa54d2391deccec473fc0f655c379e766 Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Thu, 26 Aug 2021 11:07:58 -0500 Subject: Validate new m.room.power_levels events (#10232) Signed-off-by: Aaron Raimist --- changelog.d/10232.bugfix | 1 + synapse/events/utils.py | 5 ++- synapse/events/validator.py | 77 ++++++++++++++++++++++++++++++++- synapse/python_dependencies.py | 3 +- tests/rest/client/test_power_levels.py | 78 ++++++++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10232.bugfix (limited to 'tests') diff --git a/changelog.d/10232.bugfix b/changelog.d/10232.bugfix new file mode 100644 index 0000000000..7be72271e0 --- /dev/null +++ b/changelog.d/10232.bugfix @@ -0,0 +1 @@ +Validate new `m.room.power_levels` events. Contributed by @aaronraimist. \ No newline at end of file diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b6da2f60af..738a151cef 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -32,6 +32,9 @@ from . import EventBase # the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" SPLIT_FIELD_REGEX = re.compile(r"(? EventBase: """Returns a pruned version of the given event, which removes all keys we @@ -505,7 +508,7 @@ def validate_canonicaljson(value: Any): * NaN, Infinity, -Infinity """ if isinstance(value, int): - if value <= -(2 ** 53) or 2 ** 53 <= value: + if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) elif isinstance(value, float): diff --git a/synapse/events/validator.py b/synapse/events/validator.py index fa6987d7cb..33954b4f62 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -11,16 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import collections.abc from typing import Union +import jsonschema + from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.events.utils import validate_canonicaljson +from synapse.events.utils import ( + CANONICALJSON_MAX_INT, + CANONICALJSON_MIN_INT, + validate_canonicaljson, +) from synapse.federation.federation_server import server_matches_acl_event from synapse.types import EventID, RoomID, UserID @@ -87,6 +93,29 @@ class EventValidator: 400, "Can't create an ACL event that denies the local server" ) + if event.type == EventTypes.PowerLevels: + try: + jsonschema.validate( + instance=event.content, + schema=POWER_LEVELS_SCHEMA, + cls=plValidator, + ) + except jsonschema.ValidationError as e: + if e.path: + # example: "users_default": '0' is not of type 'integer' + message = '"' + e.path[-1] + '": ' + e.message # noqa: B306 + # jsonschema.ValidationError.message is a valid attribute + else: + # example: '0' is not of type 'integer' + message = e.message # noqa: B306 + # jsonschema.ValidationError.message is a valid attribute + + raise SynapseError( + code=400, + msg=message, + errcode=Codes.BAD_JSON, + ) + def _validate_retention(self, event: EventBase): """Checks that an event that defines the retention policy for a room respects the format enforced by the spec. @@ -185,3 +214,47 @@ class EventValidator: def _ensure_state_event(self, event): if not event.is_state(): raise SynapseError(400, "'%s' must be state events" % (event.type,)) + + +POWER_LEVELS_SCHEMA = { + "type": "object", + "properties": { + "ban": {"$ref": "#/definitions/int"}, + "events": {"$ref": "#/definitions/objectOfInts"}, + "events_default": {"$ref": "#/definitions/int"}, + "invite": {"$ref": "#/definitions/int"}, + "kick": {"$ref": "#/definitions/int"}, + "notifications": {"$ref": "#/definitions/objectOfInts"}, + "redact": {"$ref": "#/definitions/int"}, + "state_default": {"$ref": "#/definitions/int"}, + "users": {"$ref": "#/definitions/objectOfInts"}, + "users_default": {"$ref": "#/definitions/int"}, + }, + "definitions": { + "int": { + "type": "integer", + "minimum": CANONICALJSON_MIN_INT, + "maximum": CANONICALJSON_MAX_INT, + }, + "objectOfInts": { + "type": "object", + "additionalProperties": {"$ref": "#/definitions/int"}, + }, + }, +} + + +def _create_power_level_validator(): + validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) + + # by default jsonschema does not consider a frozendict to be an object so + # we need to use a custom type checker + # https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types + type_checker = validator.TYPE_CHECKER.redefine( + "object", lambda checker, thing: isinstance(thing, collections.abc.Mapping) + ) + + return jsonschema.validators.extend(validator, type_checker=type_checker) + + +plValidator = _create_power_level_validator() diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index cdcbdd772b..154e5b7028 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -48,7 +48,8 @@ logger = logging.getLogger(__name__) # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. REQUIREMENTS = [ - "jsonschema>=2.5.1", + # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0 + "jsonschema>=3.0.0", "frozendict>=1", "unpaddedbase64>=1.1.0", "canonicaljson>=1.4.0", diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 91d0762cb0..c0de4c93a8 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.errors import Codes +from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync @@ -203,3 +205,79 @@ class PowerLevelsTestCase(HomeserverTestCase): tok=self.admin_access_token, expect_code=200, # expect success ) + + def test_cannot_set_string_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL "0" + room_power_levels["users"].update({self.user_user_id: "0"}) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_large_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL above the max safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MAX_INT + 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_small_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL below the minimum safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MIN_INT - 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) -- cgit 1.5.1 From 1800aabfc226938036479d2ab1a750aa34cf3974 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 26 Aug 2021 21:41:44 +0100 Subject: Split `FederationHandler` in half (#10692) The idea here is to take anything to do with incoming events and move it out to a separate handler, as a way of making FederationHandler smaller. --- changelog.d/10692.misc | 1 + synapse/federation/federation_server.py | 7 +- synapse/handlers/federation.py | 1789 +------------------- synapse/handlers/federation_event.py | 1825 +++++++++++++++++++++ synapse/replication/http/federation.py | 4 +- synapse/server.py | 5 + tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_federation.py | 16 +- tests/handlers/test_presence.py | 4 +- tests/replication/test_federation_sender_shard.py | 2 +- tests/test_federation.py | 10 +- 11 files changed, 1884 insertions(+), 1781 deletions(-) create mode 100644 changelog.d/10692.misc create mode 100644 synapse/handlers/federation_event.py (limited to 'tests') diff --git a/changelog.d/10692.misc b/changelog.d/10692.misc new file mode 100644 index 0000000000..a1b0def76b --- /dev/null +++ b/changelog.d/10692.misc @@ -0,0 +1 @@ +Split the event-processing methods in `FederationHandler` into a separate `FederationEventHandler`. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e1b58d40c5..214ee948fa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -110,6 +110,7 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() + self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -787,7 +788,9 @@ class FederationServer(FederationBase): event = await self._check_sigs_and_hash(room_version, event) - return await self.handler.on_send_membership_event(origin, event) + return await self._federation_event_handler.on_send_membership_event( + origin, event + ) async def on_event_auth( self, origin: str, room_id: str, event_id: str @@ -1005,7 +1008,7 @@ class FederationServer(FederationBase): async with lock: logger.info("handling received PDU: %s", event) try: - await self.handler.on_receive_pdu(origin, event) + await self._federation_event_handler.on_receive_pdu(origin, event) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6fa2fc8f52..daf1d3bfb3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,23 +17,9 @@ import itertools import logging -from collections.abc import Container from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union -import attr -from prometheus_client import Counter from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -41,19 +27,12 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import ( - EventContentFields, - EventTypes, - Membership, - RejectedReason, - RoomEncryptionAlgorithms, -) +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.errors import ( AuthError, CodeMessageException, Codes, FederationDeniedError, - FederationError, HttpResponseException, NotFoundError, RequestSendFailed, @@ -61,7 +40,6 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature -from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -75,28 +53,14 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.utils import log_function -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, - ReplicationFederationSendEventsRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, ) -from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import ( - JsonDict, - MutableStateMap, - PersistedEventPosition, - RoomStreamToken, - StateMap, - UserID, - get_domain_from_id, -) -from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.iterutils import batch_iter +from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination -from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server if TYPE_CHECKING: @@ -104,45 +68,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -soft_failed_event_counter = Counter( - "synapse_federation_soft_failed_events_total", - "Events received over federation that we marked as soft_failed", -) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _NewEventInfo: - """Holds information about a received event, ready for passing to _auth_and_persist_events - - Attributes: - event: the received event - - claimed_auth_event_map: a map of (type, state_key) => event for the event's - claimed auth_events. - - This can include events which have not yet been persisted, in the case that - we are backfilling a batch of events. - - Note: May be incomplete: if we were unable to find all of the claimed auth - events. Also, treat the contents with caution: the events might also have - been rejected, might not yet have been authorized themselves, or they might - be in the wrong room. - - """ - - event: EventBase - claimed_auth_event_map: StateMap[EventBase] - class FederationHandler(BaseHandler): - """Handles events that originated from federation. - Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resolutions) - b) converting events that were produced by local clients that may need - to be sent to remote homeservers. - c) doing the necessary dances to invite remote users and join remote - rooms. + """Handles general incoming federation requests + + Incoming events are *not* handled here, for which see FederationEventHandler. """ def __init__(self, hs: "HomeServer"): @@ -155,652 +85,35 @@ class FederationHandler(BaseHandler): self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() - self._state_resolution_handler = hs.get_state_resolution_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() - self.action_generator = hs.get_action_generator() self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() - self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_proxied_blacklisted_http_client() - self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() + self._federation_event_handler = hs.get_federation_event_handler() - self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs ) if hs.config.worker_app: - self._user_device_resync = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) self._maybe_store_room_on_outlier_membership = ( ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) else: - self._device_list_updater = hs.get_device_handler().device_list_updater self._maybe_store_room_on_outlier_membership = ( self.store.maybe_store_room_on_outlier_membership ) - # When joining a room we need to queue any events for that room up. - # For each room, a list of (pdu, origin) tuples. - self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} - self._room_pdu_linearizer = Linearizer("fed_room_pdu") - self._room_backfill = Linearizer("room_backfill") self.third_party_event_rules = hs.get_third_party_event_rules() - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - - async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: - """Process a PDU received via a federation /send/ transaction - - Args: - origin: server which initiated the /send/ transaction. Will - be used to fetch missing events or state. - pdu: received PDU - """ - - room_id = pdu.room_id - event_id = pdu.event_id - - # We reprocess pdus when we have seen them only as outliers - existing = await self.store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - - # FIXME: Currently we fetch an event again when we already have it - # if it has been marked as an outlier. - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "Ignoring received event %s which we have already seen", event_id - ) - return - if pdu.internal_metadata.is_outlier(): - logger.info( - "Ignoring received outlier %s which we already have as an outlier", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - - # do some initial sanity-checking of the event. In particular, make - # sure it doesn't have hundreds of prev_events or auth_events, which - # could cause a huge state resolution or cascade of event fetches. - try: - self._sanity_check_event(pdu) - except SynapseError as err: - logger.warning("Received event failed sanity checks") - raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) - - # If we are currently in the process of joining this room, then we - # queue up events for later processing. - if room_id in self.room_queues: - logger.info( - "Queuing PDU from %s for now: join in progress", - origin, - ) - self.room_queues[room_id].append((pdu, origin)) - return - - # If we're not in the room just ditch the event entirely. This is - # probably an old server that has come back and thinks we're still in - # the room (or we've been rejoined to the room by a state reset). - # - # Note that if we were never in the room then we would have already - # dropped the event, since we wouldn't know the room version. - is_in_room = await self._event_auth_handler.check_host_in_room( - room_id, self.server_name - ) - if not is_in_room: - logger.info( - "Ignoring PDU from %s as we're not in the room", - origin, - ) - return None - - # Check that the event passes auth based on the state at the event. This is - # done for events that are to be added to the timeline (non-outliers). - # - # Get missing pdus if necessary: - # - Fetching any missing prev events to fill in gaps in the graph - # - Fetching state if we have a hole in the graph - if not pdu.internal_metadata.is_outlier(): - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if missing_prevs: - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("min_depth: %d", min_depth) - - if min_depth is not None and pdu.depth > min_depth: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. - logger.info( - "Acquiring room lock to fetch %d missing prev_events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): - logger.info( - "Acquired room lock to fetch %d missing prev_events", - len(missing_prevs), - ) - - try: - await self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) - except Exception as e: - raise Exception( - "Error fetching missing prev_events for %s: %s" - % (event_id, e) - ) from e - - # Update the set of things we've seen after trying to - # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - logger.info("Found all missing prev_events") - - if missing_prevs: - # since this event was pushed to us, it is possible for it to - # become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very - # bad. Further, if the event was pushed to us, there is no excuse - # for us not to have all the prev_events. (XXX: apart from - # min_depth?) - # - # We therefore reject any such events. - logger.warning( - "Rejecting: failed to fetch %d prev events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - - await self._process_received_pdu(origin, pdu, state=None) - - async def _get_missing_events_for_pdu( - self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int - ) -> None: - """ - Args: - origin: Origin of the pdu. Will be called to get the missing events - pdu: received pdu - prevs: List of event ids which we are missing - min_depth: Minimum depth of events to return. - """ - - room_id = pdu.room_id - event_id = pdu.event_id - - seen = await self.store.have_events_in_timeline(prevs) - - if not prevs - seen: - return - - latest_list = await self.store.get_latest_event_ids_in_room(room_id) - - # We add the prev events that we have seen to the latest - # list to ensure the remote server doesn't give them to us - latest = set(latest_list) - latest |= seen - - logger.info( - "Requesting missing events between %s and %s", - shortstr(latest), - event_id, - ) - - # XXX: we set timeout to 10s to help workaround - # https://github.com/matrix-org/synapse/issues/1733. - # The reason is to avoid holding the linearizer lock - # whilst processing inbound /send transactions, causing - # FDs to stack up and block other inbound transactions - # which empirically can currently take up to 30 minutes. - # - # N.B. this explicitly disables retry attempts. - # - # N.B. this also increases our chances of falling back to - # fetching fresh state for the room if the missing event - # can't be found, which slightly reduces our security. - # it may also increase our DAG extremity count for the room, - # causing additional state resolution? See #1760. - # However, fetching state doesn't hold the linearizer lock - # apparently. - # - # see https://github.com/matrix-org/synapse/pull/1744 - # - # ---- - # - # Update richvdh 2018/09/18: There are a number of problems with timing this - # request out aggressively on the client side: - # - # - it plays badly with the server-side rate-limiter, which starts tarpitting you - # if you send too many requests at once, so you end up with the server carefully - # working through the backlog of your requests, which you have already timed - # out. - # - # - for this request in particular, we now (as of - # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the - # server can't produce a plausible-looking set of prev_events - so we becone - # much more likely to reject the event. - # - # - contrary to what it says above, we do *not* fall back to fetching fresh state - # for the room if get_missing_events times out. Rather, we give up processing - # the PDU whose prevs we are missing, which then makes it much more likely that - # we'll end up back here for the *next* PDU in the list, which exacerbates the - # problem. - # - # - the aggressive 10s timeout was introduced to deal with incoming federation - # requests taking 8 hours to process. It's not entirely clear why that was going - # on; certainly there were other issues causing traffic storms which are now - # resolved, and I think in any case we may be more sensible about our locking - # now. We're *certainly* more sensible about our logging. - # - # All that said: Let's try increasing the timeout to 60s and see what happens. - - try: - missing_events = await self.federation_client.get_missing_events( - origin, - room_id, - earliest_events_ids=list(latest), - latest_events=[pdu], - limit=10, - min_depth=min_depth, - timeout=60000, - ) - except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: - # We failed to get the missing events, but since we need to handle - # the case of `get_missing_events` not returning the necessary - # events anyway, it is safe to simply log the error and continue. - logger.warning("Failed to get prev_events: %s", e) - return - - logger.info("Got %d prev_events", len(missing_events)) - await self._process_pulled_events(origin, missing_events, backfilled=False) - - async def _get_state_after_missing_prev_event( - self, - destination: str, - room_id: str, - event_id: str, - ) -> List[EventBase]: - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination: The remote homeserver to query for the state. - room_id: The id of the room we're interested in. - event_id: The id of the event we want the state at. - - Returns: - A list of events in the state, including the event itself - """ - ( - state_event_ids, - auth_event_ids, - ) = await self.federation_client.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - logger.debug( - "state_ids returned %i state events, %i auth events", - len(state_event_ids), - len(auth_event_ids), - ) - - # start by just trying to fetch the events from the store - desired_events = set(state_event_ids) - desired_events.add(event_id) - logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self.store.get_events( - desired_events, allow_rejected=True - ) - - missing_desired_events = desired_events - fetched_events.keys() - logger.debug( - "We are missing %i events (got %i)", - len(missing_desired_events), - len(fetched_events), - ) - - # We probably won't need most of the auth events, so let's just check which - # we have for now, rather than thrashing the event cache with them all - # unnecessarily. - - # TODO: we probably won't actually need all of the auth events, since we - # already have a bunch of the state events. It would be nice if the - # federation api gave us a way of finding out which we actually need. - - missing_auth_events = set(auth_event_ids) - fetched_events.keys() - missing_auth_events.difference_update( - await self.store.have_seen_events(room_id, missing_auth_events) - ) - logger.debug("We are also missing %i auth events", len(missing_auth_events)) - - missing_events = missing_desired_events | missing_auth_events - logger.debug("Fetching %i events from remote", len(missing_events)) - await self._get_events_and_persist( - destination=destination, room_id=room_id, events=missing_events - ) - - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self.store.get_events(missing_desired_events, allow_rejected=True) - ) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] - - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) - - del fetched_events[bad_event_id] - - # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - - # missing state at that event is a warning, not a blocker - # XXX: this doesn't sound right? it means that we'll end up with incomplete - # state. - failed_to_fetch = desired_events - fetched_events.keys() - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state events for %s %s", - event_id, - failed_to_fetch, - ) - - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - - return remote_state - - async def _process_received_pdu( - self, - origin: str, - event: EventBase, - state: Optional[Iterable[EventBase]], - backfilled: bool = False, - ) -> None: - """Called when we have a new pdu. We need to do auth checks and put it - through the StateHandler. - - Args: - origin: server sending the event - - event: event to be persisted - - state: Normally None, but if we are handling a gap in the graph - (ie, we are missing one or more prev_events), the resolved state at the - event - - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - logger.debug("Processing event: %s", event) - - try: - context = await self.state_handler.compute_event_context( - event, old_state=state - ) - await self._auth_and_persist_event( - origin, event, context, state=state, backfilled=backfilled - ) - except AuthError as e: - raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - - if backfilled: - return - - # For encrypted messages we check that we know about the sending device, - # if we don't then we mark the device cache for that user as stale. - if event.type == EventTypes.Encrypted: - device_id = event.content.get("device_id") - sender_key = event.content.get("sender_key") - - cached_devices = await self.store.get_cached_devices_for_user(event.sender) - - resync = False # Whether we should resync device lists. - - device = None - if device_id is not None: - device = cached_devices.get(device_id) - if device is None: - logger.info( - "Received event from remote device not in our cache: %s %s", - event.sender, - device_id, - ) - resync = True - - # We also check if the `sender_key` matches what we expect. - if sender_key is not None: - # Figure out what sender key we're expecting. If we know the - # device and recognize the algorithm then we can work out the - # exact key to expect. Otherwise check it matches any key we - # have for that device. - - current_keys: Container[str] = [] - - if device: - keys = device.get("keys", {}).get("keys", {}) - - if ( - event.content.get("algorithm") - == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 - ): - # For this algorithm we expect a curve25519 key. - key_name = "curve25519:%s" % (device_id,) - current_keys = [keys.get(key_name)] - else: - # We don't know understand the algorithm, so we just - # check it matches a key for the device. - current_keys = keys.values() - elif device_id: - # We don't have any keys for the device ID. - pass - else: - # The event didn't include a device ID, so we just look for - # keys across all devices. - current_keys = [ - key - for device in cached_devices.values() - for key in device.get("keys", {}).get("keys", {}).values() - ] - - # We now check that the sender key matches (one of) the expected - # keys. - if sender_key not in current_keys: - logger.info( - "Received event from remote device with unexpected sender key: %s %s: %s", - event.sender, - device_id or "", - sender_key, - ) - resync = True - - if resync: - run_as_background_process( - "resync_device_due_to_pdu", self._resync_device, event.sender - ) - - await self._handle_marker_event(origin, event) - - async def _handle_marker_event(self, origin: str, marker_event: EventBase): - """Handles backfilling the insertion event when we receive a marker - event that points to one. - - Args: - origin: Origin of the event. Will be called to get the insertion event - marker_event: The event to process - """ - - if marker_event.type != EventTypes.MSC2716_MARKER: - # Not a marker event - return - - if marker_event.rejected_reason is not None: - # Rejected event - return - - # Skip processing a marker event if the room version doesn't - # support it. - room_version = await self.store.get_room_version(marker_event.room_id) - if not room_version.msc2716_historical: - return - - logger.debug("_handle_marker_event: received %s", marker_event) - - insertion_event_id = marker_event.content.get( - EventContentFields.MSC2716_MARKER_INSERTION - ) - - if insertion_event_id is None: - # Nothing to retrieve then (invalid marker) - return - - logger.debug( - "_handle_marker_event: backfilling insertion event %s", insertion_event_id - ) - - await self._get_events_and_persist( - origin, - marker_event.room_id, - [insertion_event_id], - ) - - insertion_event = await self.store.get_event( - insertion_event_id, allow_none=True - ) - if insertion_event is None: - logger.warning( - "_handle_marker_event: server %s didn't return insertion event %s for marker %s", - origin, - insertion_event_id, - marker_event.event_id, - ) - return - - logger.debug( - "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", - insertion_event, - marker_event, - ) - - await self.store.insert_insertion_extremity( - insertion_event_id, marker_event.room_id - ) - - logger.debug( - "_handle_marker_event: insertion extremity added for %s from marker event %s", - insertion_event, - marker_event, - ) - - async def _resync_device(self, sender: str) -> None: - """We have detected that the device list for the given user may be out - of sync, so we try and resync them. - """ - - try: - await self.store.mark_remote_user_device_cache_as_stale(sender) - - # Immediately attempt a resync in the background - if self.config.worker_app: - await self._user_device_resync(user_id=sender) - else: - await self._device_list_updater.user_device_resync(sender) - except Exception: - logger.exception("Failed to resync device for %s", sender) - - @log_function - async def backfill( - self, dest: str, room_id: str, limit: int, extremities: List[str] - ) -> None: - """Trigger a backfill request to `dest` for the given `room_id` - - This will attempt to get more events from the remote. If the other side - has no new events to offer, this will return an empty list. - - As the events are received, we check their signatures, and also do some - sanity-checking on them. If any of the backfilled events are invalid, - this method throws a SynapseError. - - We might also raise an InvalidResponseError if the response from the remote - server is just bogus. - - TODO: make this more useful to distinguish failures of the remote - server from invalid events (there is probably no point in trying to - re-fetch invalid events from every other HS in the room.) - """ - if dest == self.server_name: - raise SynapseError(400, "Can't backfill from self.") - - events = await self.federation_client.backfill( - dest, room_id, limit=limit, extremities=extremities - ) - - if not events: - return - - # if there are any events in the wrong room, the remote server is buggy and - # should not be trusted. - for ev in events: - if ev.room_id != room_id: - raise InvalidResponseError( - f"Remote server {dest} returned event {ev.event_id} which is in " - f"room {ev.room_id}, when we were backfilling in {room_id}" - ) - - await self._process_pulled_events(dest, events, backfilled=True) - async def maybe_backfill( self, room_id: str, current_depth: int, limit: int ) -> bool: @@ -995,7 +308,7 @@ class FederationHandler(BaseHandler): # TODO: Should we try multiple of these at a time? for dom in domains: try: - await self.backfill( + await self._federation_event_handler.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -1084,316 +397,7 @@ class FederationHandler(BaseHandler): tried_domains.update(dom for dom, _ in likely_extremeties_domains) - return False - - async def _get_events_and_persist( - self, destination: str, room_id: str, events: Iterable[str] - ) -> None: - """Fetch the given events from a server, and persist them as outliers. - - This function *does not* recursively get missing auth events of the - newly fetched events. Callers must include in the `events` argument - any missing events from the auth chain. - - Logs a warning if we can't find the given event. - """ - - room_version = await self.store.get_room_version(room_id) - - event_map: Dict[str, EventBase] = {} - - async def get_event(event_id: str): - with nested_logging_context(event_id): - try: - event = await self.federation_client.get_pdu( - [destination], - event_id, - room_version, - outlier=True, - ) - if event is None: - logger.warning( - "Server %s didn't return event %s", - destination, - event_id, - ) - return - - event_map[event.event_id] = event - - except Exception as e: - logger.warning( - "Error fetching missing state/auth event %s: %s %s", - event_id, - type(e), - e, - ) - - await concurrently_execute(get_event, events, 5) - - # Make a map of auth events for each event. We do this after fetching - # all the events as some of the events' auth events will be in the list - # of requested events. - - auth_events = [ - aid - for event in event_map.values() - for aid in event.auth_event_ids() - if aid not in event_map - ] - persisted_events = await self.store.get_events( - auth_events, - allow_rejected=True, - ) - - event_infos = [] - for event in event_map.values(): - auth = {} - for auth_event_id in event.auth_event_ids(): - ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) - if ae: - auth[(ae.type, ae.state_key)] = ae - else: - logger.info("Missing auth event %s", auth_event_id) - - event_infos.append(_NewEventInfo(event, auth)) - - if event_infos: - await self._auth_and_persist_events( - destination, - room_id, - event_infos, - ) - - async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase], backfilled: bool - ) -> None: - """Process a batch of events we have pulled from a remote server - - Pulls in any events required to auth the events, persists the received events, - and notifies clients, if appropriate. - - Assumes the events have already had their signatures and hashes checked. - - Params: - origin: The server we received these events from - events: The received events. - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - - # We want to sort these by depth so we process them and - # tell clients about them in order. - sorted_events = sorted(events, key=lambda x: x.depth) - - for ev in sorted_events: - with nested_logging_context(ev.event_id): - await self._process_pulled_event(origin, ev, backfilled=backfilled) - - async def _process_pulled_event( - self, origin: str, event: EventBase, backfilled: bool - ) -> None: - """Process a single event that we have pulled from a remote server - - Pulls in any events required to auth the event, persists the received event, - and notifies clients, if appropriate. - - Assumes the event has already had its signatures and hashes checked. - - This is somewhat equivalent to on_receive_pdu, but applies somewhat different - logic in the case that we are missing prev_events (in particular, it just - requests the state at that point, rather than triggering a get_missing_events) - - so is appropriate when we have pulled the event from a remote server, rather - than having it pushed to us. - - Params: - origin: The server we received this event from - events: The received event - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - logger.info("Processing pulled event %s", event) - - # these should not be outliers. - assert not event.internal_metadata.is_outlier() - - event_id = event.event_id - - existing = await self.store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "Ignoring received event %s which we have already seen", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - - try: - self._sanity_check_event(event) - except SynapseError as err: - logger.warning("Event %s failed sanity check: %s", event_id, err) - return - - try: - state = await self._resolve_state_at_missing_prevs(origin, event) - await self._process_received_pdu( - origin, event, state=state, backfilled=backfilled - ) - except FederationError as e: - if e.code == 403: - logger.warning("Pulled event %s failed history check.", event_id) - else: - raise - - async def _resolve_state_at_missing_prevs( - self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: - """Calculate the state at an event with missing prev_events. - - This is used when we have pulled a batch of events from a remote server, and - still don't have all the prev_events. - - If we already have all the prev_events for `event`, this method does nothing. - - Otherwise, the missing prevs become new backwards extremities, and we fall back - to asking the remote server for the state after each missing `prev_event`, - and resolving across them. - - That's ok provided we then resolve the state against other bits of the DAG - before using it - in other words, that the received event `event` is not going - to become the only forwards_extremity in the room (which will ensure that you - can't just take over a room by sending an event, withholding its prev_events, - and declaring yourself to be an admin in the subsequent state request). - - In other words: we should only call this method if `event` has been *pulled* - as part of a batch of missing prev events, or similar. - - Params: - dest: the remote server to ask for state at the missing prevs. Typically, - this will be the server we got `event` from. - event: an event to check for missing prevs. - - Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. - """ - room_id = event.room_id - event_id = event.event_id - - prevs = set(event.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - return None - - logger.info( - "Event %s is missing prev_events %s: calculating state for a " - "backwards extremity", - event_id, - shortstr(missing_prevs), - ) - # Calculate the state after each of the previous events, and - # resolve them to find the correct state at the current event. - event_map = {event_id: event} - try: - # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) - - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours - - # Ask the remote server for the states we don't - # know about - for p in missing_prevs: - logger.info("Requesting state after missing prev_event %s", p) - - with nested_logging_context(p): - # note that if any of the missing prevs share missing state or - # auth events, the requests to fetch those events are deduped - # by the get_pdu_cache in federation_client. - remote_state = await self._get_state_after_missing_prev_event( - dest, room_id, p - ) - - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x - - room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), - ) - - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] - except Exception: - logger.warning( - "Error attempting to resolve state at missing prev_events", - exc_info=True, - ) - raise FederationError( - "ERROR", - 403, - "We can't get valid state history.", - affected=event_id, - ) - return state - - def _sanity_check_event(self, ev: EventBase) -> None: - """ - Do some early sanity checks of a received event - - In particular, checks it doesn't have an excessive number of - prev_events or auth_events, which could cause a huge state resolution - or cascade of event fetches. - - Args: - ev: event to be checked - - Raises: - SynapseError if the event does not pass muster - """ - if len(ev.prev_event_ids()) > 20: - logger.warning( - "Rejecting event %s which has %i prev_events", - ev.event_id, - len(ev.prev_event_ids()), - ) - raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") - - if len(ev.auth_event_ids()) > 10: - logger.warning( - "Rejecting event %s which has %i auth_events", - ev.event_id, - len(ev.auth_event_ids()), - ) - raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + return False async def send_invite(self, target_host: str, event: EventBase) -> EventBase: """Sends the invite to the remote server for signing. @@ -1460,9 +464,9 @@ class FederationHandler(BaseHandler): # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. - assert room_id not in self.room_queues + assert room_id not in self._federation_event_handler.room_queues - self.room_queues[room_id] = [] + self._federation_event_handler.room_queues[room_id] = [] await self._clean_room_for_join(room_id) @@ -1536,8 +540,8 @@ class FederationHandler(BaseHandler): logger.debug("Finished joining %s to %s", joinee, room_id) return event.event_id, max_stream_id finally: - room_queue = self.room_queues[room_id] - del self.room_queues[room_id] + room_queue = self._federation_event_handler.room_queues[room_id] + del self._federation_event_handler.room_queues[room_id] # we don't need to wait for the queued events to be processed - # it's just a best-effort thing at this point. We do want to do @@ -1613,7 +617,7 @@ class FederationHandler(BaseHandler): event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( + stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) return event.event_id, stream_id @@ -1633,7 +637,7 @@ class FederationHandler(BaseHandler): p, ) with nested_logging_context(p.event_id): - await self.on_receive_pdu(origin, p) + await self._federation_event_handler.on_receive_pdu(origin, p) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e @@ -1726,7 +730,7 @@ class FederationHandler(BaseHandler): raise # Ensure the user can even join the room. - await self._check_join_restrictions(context, event) + await self._federation_event_handler.check_join_restrictions(context, event) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1803,7 +807,9 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify(event.room_id, [(event, context)]) + await self._federation_event_handler.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event @@ -1830,7 +836,7 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( + stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1973,116 +979,6 @@ class FederationHandler(BaseHandler): return event - @log_function - async def on_send_membership_event( - self, origin: str, event: EventBase - ) -> Tuple[EventBase, EventContext]: - """ - We have received a join/leave/knock event for a room via send_join/leave/knock. - - Verify that event and send it into the room on the remote homeserver's behalf. - - This is quite similar to on_receive_pdu, with the following principal - differences: - * only membership events are permitted (and only events with - sender==state_key -- ie, no kicks or bans) - * *We* send out the event on behalf of the remote server. - * We enforce the membership restrictions of restricted rooms. - * Rejected events result in an exception rather than being stored. - - There are also other differences, however it is not clear if these are by - design or omission. In particular, we do not attempt to backfill any missing - prev_events. - - Args: - origin: The homeserver of the remote (joining/invited/knocking) user. - event: The member event that has been signed by the remote homeserver. - - Returns: - The event and context of the event after inserting it into the room graph. - - Raises: - SynapseError if the event is not accepted into the room - """ - logger.debug( - "on_send_membership_event: Got event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - if get_domain_from_id(event.sender) != origin: - logger.info( - "Got send_membership request for user %r from different origin %s", - event.sender, - origin, - ) - raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - - if event.sender != event.state_key: - raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) - - assert not event.internal_metadata.outlier - - # Send this event on behalf of the other server. - # - # The remote server isn't a full participant in the room at this point, so - # may not have an up-to-date list of the other homeservers participating in - # the room, so we send it on their behalf. - event.internal_metadata.send_on_behalf_of = origin - - context = await self.state_handler.compute_event_context(event) - context = await self._check_event_auth(origin, event, context) - if context.rejected: - raise SynapseError( - 403, f"{event.membership} event was rejected", Codes.FORBIDDEN - ) - - # for joins, we need to check the restrictions of restricted rooms - if event.membership == Membership.JOIN: - await self._check_join_restrictions(context, event) - - # for knock events, we run the third-party event rules. It's not entirely clear - # why we don't do this for other sorts of membership events. - if event.membership == Membership.KNOCK: - event_allowed, _ = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of knock %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - - # all looks good, we can persist the event. - await self._run_push_actions_and_persist_event(event, context) - return event, context - - async def _check_join_restrictions( - self, context: EventContext, event: EventBase - ) -> None: - """Check that restrictions in restricted join rules are matched - - Called when we receive a join event via send_join. - - Raises an auth error if the restrictions are not matched. - """ - prev_state_ids = await context.get_prev_state_ids() - - # Check if the user is already in the room or invited to the room. - user_id = event.state_key - prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - - # Check if the member should be allowed access via membership in a space. - await self._event_auth_handler.check_restricted_join_rules( - prev_state_ids, - event.room_version, - user_id, - prev_member_event, - ) - async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event.""" @@ -2183,126 +1079,6 @@ class FederationHandler(BaseHandler): else: return None - async def get_min_depth_for_context(self, context: str) -> int: - return await self.store.get_min_depth(context) - - async def _auth_and_persist_event( - self, - origin: str, - event: EventBase, - context: EventContext, - state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, - backfilled: bool = False, - ) -> None: - """ - Process an event by performing auth checks and then persisting to the database. - - Args: - origin: The host the event originates from. - event: The event itself. - context: - The event context. - - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly incomplete, and possibly including events that are not yet - persisted, or authed, or in the right room. - - Only populated where we may not already have persisted these events - - for example, when populating outliers. - - backfilled: True if the event was backfilled. - """ - context = await self._check_event_auth( - origin, - event, - context, - state=state, - claimed_auth_event_map=claimed_auth_event_map, - backfilled=backfilled, - ) - - await self._run_push_actions_and_persist_event(event, context, backfilled) - - async def _run_push_actions_and_persist_event( - self, event: EventBase, context: EventContext, backfilled: bool = False - ): - """Run the push actions for a received event, and persist it. - - Args: - event: The event itself. - context: The event context. - backfilled: True if the event was backfilled. - """ - try: - if ( - not event.internal_metadata.is_outlier() - and not backfilled - and not context.rejected - and (await self.store.get_min_depth(event.room_id)) <= event.depth - ): - await self.action_generator.handle_push_actions_for_event( - event, context - ) - - await self.persist_events_and_notify( - event.room_id, [(event, context)], backfilled=backfilled - ) - except Exception: - run_in_background( - self.store.remove_push_actions_from_staging, event.event_id - ) - raise - - async def _auth_and_persist_events( - self, - origin: str, - room_id: str, - event_infos: Collection[_NewEventInfo], - ) -> None: - """Creates the appropriate contexts and persists events. The events - should not depend on one another, e.g. this should be used to persist - a bunch of outliers, but not a chunk of individual events that depend - on each other for state calculations. - - Notifies about the events where appropriate. - """ - - if not event_infos: - return - - async def prep(ev_info: _NewEventInfo): - event = ev_info.event - with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context(event) - res = await self._check_event_auth( - origin, - event, - res, - claimed_auth_event_map=ev_info.claimed_auth_event_map, - ) - return res - - contexts = await make_deferred_yieldable( - defer.gatherResults( - [run_in_background(prep, ev_info) for ev_info in event_infos], - consumeErrors=True, - ) - ) - - await self.persist_events_and_notify( - room_id, - [ - (ev_info.event, context) - for ev_info, context in zip(event_infos, contexts) - ], - ) - async def _persist_auth_tree( self, origin: str, @@ -2400,7 +1176,7 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR if auth_events or state: - await self.persist_events_and_notify( + await self._federation_event_handler.persist_events_and_notify( room_id, [ (e, events_to_context[e.event_id]) @@ -2412,108 +1188,10 @@ class FederationHandler(BaseHandler): event, old_state=state ) - return await self.persist_events_and_notify( + return await self._federation_event_handler.persist_events_and_notify( room_id, [(event, new_event_context)] ) - async def _check_for_soft_fail( - self, - event: EventBase, - state: Optional[Iterable[EventBase]], - backfilled: bool, - origin: str, - ) -> None: - """Checks if we should soft fail the event; if so, marks the event as - such. - - Args: - event - state: The state at the event if we don't have all the event's prev events - backfilled: Whether the event is from backfill - origin: The host the event originates from. - """ - # For new (non-backfilled and non-outlier) events we check if the event - # passes auth based on the current state. If it doesn't then we - # "soft-fail" the event. - if backfilled or event.internal_metadata.is_outlier(): - return - - extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids_list) - prev_event_ids = set(event.prev_event_ids()) - - if extrem_ids == prev_event_ids: - # If they're the same then the current state is the same as the - # state at the event, so no point rechecking auth for soft fail. - return - - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # Calculate the "current state". - if state is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, - # however we still want to do checks as gaps are easy to - # maliciously manufacture. - # - # So we use a "current state" that is actually a state - # resolution across the current forward extremities and the - # given state at the event. This should correctly handle cases - # like bans, especially with state res v2. - - state_sets_d = await self.state_store.get_state_groups( - event.room_id, extrem_ids - ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) - state_sets.append(state) - current_states = await self.state_handler.resolve_events( - room_version, state_sets, event - ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } - else: - current_state_ids = await self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids - ) - - logger.debug( - "Doing soft-fail check for %s: state %s", - event.event_id, - current_state_ids, - ) - - # Now check if event pass auth against said current state - auth_types = auth_types_for_event(room_version_obj, event) - current_state_ids_list = [ - e for k, e in current_state_ids.items() if k in auth_types - ] - - auth_events_map = await self.store.get_events(current_state_ids_list) - current_auth_events = { - (e.type, e.state_key): e for e in auth_events_map.values() - } - - try: - event_auth.check(room_version_obj, event, auth_events=current_auth_events) - except AuthError as e: - logger.warning( - "Soft-failing %r (from %s) because %s", - event, - e, - origin, - extra={ - "room_id": event.room_id, - "mxid": event.sender, - "hs": origin, - }, - ) - soft_failed_event_counter.inc() - event.internal_metadata.soft_failed = True - async def on_get_missing_events( self, origin: str, @@ -2542,334 +1220,6 @@ class FederationHandler(BaseHandler): return missing_events - async def _check_event_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, - backfilled: bool = False, - ) -> EventContext: - """ - Checks whether an event should be rejected (for failing auth checks). - - Args: - origin: The host the event originates from. - event: The event itself. - context: - The event context. - - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly incomplete, and possibly including events that are not yet - persisted, or authed, or in the right room. - - Only populated where we may not already have persisted these events - - for example, when populating outliers, or the state for a backwards - extremity. - - backfilled: True if the event was backfilled. - - Returns: - The updated context object. - """ - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - if claimed_auth_event_map: - # if we have a copy of the auth events from the event, use that as the - # basis for auth. - auth_events = claimed_auth_event_map - else: - # otherwise, we calculate what the auth events *should* be, and use that - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - - try: - ( - context, - auth_events_for_auth, - ) = await self._update_auth_events_and_context_for_auth( - origin, event, context, auth_events - ) - except Exception: - # We don't really mind if the above fails, so lets not fail - # processing if it does. However, it really shouldn't fail so - # let's still log as an exception since we'll still want to fix - # any bugs. - logger.exception( - "Failed to double check auth events for %s with remote. " - "Ignoring failure and continuing processing of event.", - event.event_id, - ) - auth_events_for_auth = auth_events - - try: - event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) - except AuthError as e: - logger.warning("Failed auth resolution for %r because %s", event, e) - context.rejected = RejectedReason.AUTH_ERROR - - if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled, origin=origin) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) - - # If we are going to send this event over federation we precaclculate - # the joined hosts. - if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event( - event, context - ) - - return context - - async def _update_auth_events_and_context_for_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - input_auth_events: StateMap[EventBase], - ) -> Tuple[EventContext, StateMap[EventBase]]: - """Helper for _check_event_auth. See there for docs. - - Checks whether a given event has the expected auth events. If it - doesn't then we talk to the remote server to compare state to see if - we can come to a consensus (e.g. if one server missed some valid - state). - - This attempts to resolve any potential divergence of state between - servers, but is not essential and so failures should not block further - processing of the event. - - Args: - origin: - event: - context: - - input_auth_events: - Map from (event_type, state_key) to event - - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. - - Returns: - updated context, updated auth event map - """ - # take a copy of input_auth_events before we modify it. - auth_events: MutableStateMap[EventBase] = dict(input_auth_events) - - event_auth_events = set(event.auth_event_ids()) - - # missing_auth is the set of the event's auth_events which we don't yet have - # in auth_events. - missing_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - # if we have missing events, we need to fetch those events from somewhere. - # - # we start by checking if they are in the store, and then try calling /event_auth/. - if missing_auth: - have_events = await self.store.have_seen_events(event.room_id, missing_auth) - logger.debug("Events %s are in the store", have_events) - missing_auth.difference_update(have_events) - - if missing_auth: - # If we don't have all the auth events, we need to get them. - logger.info("auth_events contains unknown events: %s", missing_auth) - try: - try: - remote_auth_chain = await self.federation_client.get_event_auth( - origin, event.room_id, event.event_id - ) - except RequestSendFailed as e1: - # The other side isn't around or doesn't implement the - # endpoint, so lets just bail out. - logger.info("Failed to get event auth from remote: %s", e1) - return context, auth_events - - seen_remotes = await self.store.have_seen_events( - event.room_id, [e.event_id for e in remote_auth_chain] - ) - - for e in remote_auth_chain: - if e.event_id in seen_remotes: - continue - - if e.event_id == event.event_id: - continue - - try: - auth_ids = e.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in remote_auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - e.internal_metadata.outlier = True - - logger.debug( - "_check_event_auth %s missing_auth: %s", - event.event_id, - e.event_id, - ) - missing_auth_event_context = ( - await self.state_handler.compute_event_context(e) - ) - await self._auth_and_persist_event( - origin, - e, - missing_auth_event_context, - claimed_auth_event_map=auth, - ) - - if e.event_id in event_auth_events: - auth_events[(e.type, e.state_key)] = e - except AuthError: - pass - - except Exception: - logger.exception("Failed to get auth chain") - - if event.internal_metadata.is_outlier(): - # XXX: given that, for an outlier, we'll be working with the - # event's *claimed* auth events rather than those we calculated: - # (a) is there any point in this test, since different_auth below will - # obviously be empty - # (b) alternatively, why don't we do it earlier? - logger.info("Skipping auth_event fetch for outlier") - return context, auth_events - - different_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - if not different_auth: - return context, auth_events - - logger.info( - "auth_events refers to events which are not in our calculated auth " - "chain: %s", - different_auth, - ) - - # XXX: currently this checks for redactions but I'm not convinced that is - # necessary? - different_events = await self.store.get_events_as_list(different_auth) - - for d in different_events: - if d.room_id != event.room_id: - logger.warning( - "Event %s refers to auth_event %s which is in a different room", - event.event_id, - d.event_id, - ) - - # don't attempt to resolve the claimed auth events against our own - # in this case: just use our own auth events. - # - # XXX: should we reject the event in this case? It feels like we should, - # but then shouldn't we also do so if we've failed to fetch any of the - # auth events? - return context, auth_events - - # now we state-resolve between our own idea of the auth events, and the remote's - # idea of them. - - local_state = auth_events.values() - remote_auth_events = dict(auth_events) - remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) - remote_state = remote_auth_events.values() - - room_version = await self.store.get_room_version_id(event.room_id) - new_state = await self.state_handler.resolve_events( - room_version, (local_state, remote_state), event - ) - - logger.info( - "After state res: updating auth_events with new state %s", - { - (d.type, d.state_key): d.event_id - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, - ) - - auth_events.update(new_state) - - context = await self._update_context_for_auth_events( - event, context, auth_events - ) - - return context, auth_events - - async def _update_context_for_auth_events( - self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] - ) -> EventContext: - """Update the state_ids in an event context after auth event resolution, - storing the changes as a new state group. - - Args: - event: The event we're handling the context for - - context: initial event context - - auth_events: Events to update in the event context. - - Returns: - new event context - """ - # exclude the state key of the new event from the current_state in the context. - if event.is_state(): - event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) - else: - event_key = None - state_updates = { - k: a.event_id for k, a in auth_events.items() if k != event_key - } - - current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) # type: ignore - - current_state_ids.update(state_updates) - - prev_state_ids = await context.get_prev_state_ids() - prev_state_ids = dict(prev_state_ids) - - prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) - - # create a new state group as a delta from the existing one. - prev_group = context.state_group - state_group = await self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=state_updates, - current_state_ids=current_state_ids, - ) - - return EventContext.with_state( - state_group=state_group, - state_group_before_event=context.state_group_before_event, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=prev_group, - delta_ids=state_updates, - ) - async def construct_auth_difference( self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] ) -> Dict: @@ -3256,99 +1606,6 @@ class FederationHandler(BaseHandler): if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") - async def persist_events_and_notify( - self, - room_id: str, - event_and_contexts: Sequence[Tuple[EventBase, EventContext]], - backfilled: bool = False, - ) -> int: - """Persists events and tells the notifier/pushers about them, if - necessary. - - Args: - room_id: The room ID of events being persisted. - event_and_contexts: Sequence of events with their associated - context that should be persisted. All events must belong to - the same room. - backfilled: Whether these events are a result of - backfilling or not - - Returns: - The stream ID after which all events have been persisted. - """ - if not event_and_contexts: - return self.store.get_current_events_token() - - instance = self.config.worker.events_shard_config.get_instance(room_id) - if instance != self._instance_name: - # Limit the number of events sent over replication. We choose 200 - # here as that is what we default to in `max_request_body_size(..)` - for batch in batch_iter(event_and_contexts, 200): - result = await self._send_events( - instance_name=instance, - store=self.store, - room_id=room_id, - event_and_contexts=batch, - backfilled=backfilled, - ) - return result["max_stream_id"] - else: - assert self.storage.persistence - - # Note that this returns the events that were persisted, which may not be - # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self.storage.persistence.persist_events( - event_and_contexts, backfilled=backfilled - ) - - if self._ephemeral_messages_enabled: - for event in events: - # If there's an expiry timestamp on the event, schedule its expiry. - self._message_handler.maybe_schedule_expiry(event) - - if not backfilled: # Never notify for backfilled events - for event in events: - await self._notify_persisted_event(event, max_stream_token) - - return max_stream_token.stream - - async def _notify_persisted_event( - self, event: EventBase, max_stream_token: RoomStreamToken - ) -> None: - """Checks to see if notifier/pushers should be notified about the - event or not. - - Args: - event: - max_stream_id: The max_stream_id returned by persist_events - """ - - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - - # We notify for memberships if its an invite for one of our - # users - if event.internal_metadata.is_outlier(): - if event.membership != Membership.INVITE: - if not self.is_mine_id(target_user_id): - return - - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - elif event.internal_metadata.is_outlier(): - return - - # the event has been persisted so it should have a stream ordering. - assert event.internal_metadata.stream_ordering - - event_pos = PersistedEventPosition( - self._instance_name, event.internal_metadata.stream_ordering - ) - self.notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users - ) - async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the server to join the room. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py new file mode 100644 index 0000000000..9f055f00cf --- /dev/null +++ b/synapse/handlers/federation_event.py @@ -0,0 +1,1825 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http import HTTPStatus +from typing import ( + TYPE_CHECKING, + Collection, + Container, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) + +import attr +from prometheus_client import Counter + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RejectedReason, + RoomEncryptionAlgorithms, +) +from synapse.api.errors import ( + AuthError, + Codes, + FederationError, + HttpResponseException, + RequestSendFailed, + SynapseError, +) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers._base import BaseHandler +from synapse.logging.context import ( + make_deferred_yieldable, + nested_logging_context, + run_in_background, +) +from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet +from synapse.replication.http.federation import ( + ReplicationFederationSendEventsRestServlet, +) +from synapse.state import StateResolutionStore +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import ( + MutableStateMap, + PersistedEventPosition, + RoomStreamToken, + StateMap, + UserID, + get_domain_from_id, +) +from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.iterutils import batch_iter +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + +soft_failed_event_counter = Counter( + "synapse_federation_soft_failed_events_total", + "Events received over federation that we marked as soft_failed", +) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _NewEventInfo: + """Holds information about a received event, ready for passing to _auth_and_persist_events + + Attributes: + event: the received event + + claimed_auth_event_map: a map of (type, state_key) => event for the event's + claimed auth_events. + + This can include events which have not yet been persisted, in the case that + we are backfilling a batch of events. + + Note: May be incomplete: if we were unable to find all of the claimed auth + events. Also, treat the contents with caution: the events might also have + been rejected, might not yet have been authorized themselves, or they might + be in the wrong room. + + """ + + event: EventBase + claimed_auth_event_map: StateMap[EventBase] + + +class FederationEventHandler(BaseHandler): + """Handles events that originated from federation. + + Responsible for handing incoming events and passing them on to the rest + of the homeserver (including auth and state conflict resolutions) + """ + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state + + self.state_handler = hs.get_state_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self._event_auth_handler = hs.get_event_auth_handler() + self._message_handler = hs.get_message_handler() + self.action_generator = hs.get_action_generator() + self._state_resolution_handler = hs.get_state_resolution_handler() + + self.federation_client = hs.get_federation_client() + self.third_party_event_rules = hs.get_third_party_event_rules() + + self.is_mine_id = hs.is_mine_id + self._instance_name = hs.get_instance_name() + + self.config = hs.config + self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages + + self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) + if hs.config.worker_app: + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + else: + self._device_list_updater = hs.get_device_handler().device_list_updater + + # When joining a room we need to queue any events for that room up. + # For each room, a list of (pdu, origin) tuples. + # TODO: replace this with something more elegant, probably based around the + # federation event staging area. + self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} + + self._room_pdu_linearizer = Linearizer("fed_room_pdu") + + async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: + """Process a PDU received via a federation /send/ transaction + + Args: + origin: server which initiated the /send/ transaction. Will + be used to fetch missing events or state. + pdu: received PDU + """ + + room_id = pdu.room_id + event_id = pdu.event_id + + # We reprocess pdus when we have seen them only as outliers + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + + # FIXME: Currently we fetch an event again when we already have it + # if it has been marked as an outlier. + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", event_id + ) + return + if pdu.internal_metadata.is_outlier(): + logger.info( + "Ignoring received outlier %s which we already have as an outlier", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + # do some initial sanity-checking of the event. In particular, make + # sure it doesn't have hundreds of prev_events or auth_events, which + # could cause a huge state resolution or cascade of event fetches. + try: + self._sanity_check_event(pdu) + except SynapseError as err: + logger.warning("Received event failed sanity checks") + raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) + + # If we are currently in the process of joining this room, then we + # queue up events for later processing. + if room_id in self.room_queues: + logger.info( + "Queuing PDU from %s for now: join in progress", + origin, + ) + self.room_queues[room_id].append((pdu, origin)) + return + + # If we're not in the room just ditch the event entirely. This is + # probably an old server that has come back and thinks we're still in + # the room (or we've been rejoined to the room by a state reset). + # + # Note that if we were never in the room then we would have already + # dropped the event, since we wouldn't know the room version. + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) + if not is_in_room: + logger.info( + "Ignoring PDU from %s as we're not in the room", + origin, + ) + return None + + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). + # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph + if not pdu.internal_metadata.is_outlier(): + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if missing_prevs: + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) + + if min_depth is not None and pdu.depth > min_depth: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + logger.info( + "Acquiring room lock to fetch %d missing prev_events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + with (await self._room_pdu_linearizer.queue(pdu.room_id)): + logger.info( + "Acquired room lock to fetch %d missing prev_events", + len(missing_prevs), + ) + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) from e + + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + logger.info("Found all missing prev_events") + + if missing_prevs: + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. + logger.warning( + "Rejecting: failed to fetch %d prev events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) + + await self._process_received_pdu(origin, pdu, state=None) + + @log_function + async def on_send_membership_event( + self, origin: str, event: EventBase + ) -> Tuple[EventBase, EventContext]: + """ + We have received a join/leave/knock event for a room via send_join/leave/knock. + + Verify that event and send it into the room on the remote homeserver's behalf. + + This is quite similar to on_receive_pdu, with the following principal + differences: + * only membership events are permitted (and only events with + sender==state_key -- ie, no kicks or bans) + * *We* send out the event on behalf of the remote server. + * We enforce the membership restrictions of restricted rooms. + * Rejected events result in an exception rather than being stored. + + There are also other differences, however it is not clear if these are by + design or omission. In particular, we do not attempt to backfill any missing + prev_events. + + Args: + origin: The homeserver of the remote (joining/invited/knocking) user. + event: The member event that has been signed by the remote homeserver. + + Returns: + The event and context of the event after inserting it into the room graph. + + Raises: + SynapseError if the event is not accepted into the room + """ + logger.debug( + "on_send_membership_event: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got send_membership request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + if event.sender != event.state_key: + raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) + + assert not event.internal_metadata.outlier + + # Send this event on behalf of the other server. + # + # The remote server isn't a full participant in the room at this point, so + # may not have an up-to-date list of the other homeservers participating in + # the room, so we send it on their behalf. + event.internal_metadata.send_on_behalf_of = origin + + context = await self.state_handler.compute_event_context(event) + context = await self._check_event_auth(origin, event, context) + if context.rejected: + raise SynapseError( + 403, f"{event.membership} event was rejected", Codes.FORBIDDEN + ) + + # for joins, we need to check the restrictions of restricted rooms + if event.membership == Membership.JOIN: + await self.check_join_restrictions(context, event) + + # for knock events, we run the third-party event rules. It's not entirely clear + # why we don't do this for other sorts of membership events. + if event.membership == Membership.KNOCK: + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + # all looks good, we can persist the event. + await self._run_push_actions_and_persist_event(event, context) + return event, context + + async def check_join_restrictions( + self, context: EventContext, event: EventBase + ) -> None: + """Check that restrictions in restricted join rules are matched + + Called when we receive a join event via send_join. + + Raises an auth error if the restrictions are not matched. + """ + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + + # Check if the member should be allowed access via membership in a space. + await self._event_auth_handler.check_restricted_join_rules( + prev_state_ids, + event.room_version, + user_id, + prev_member_event, + ) + + @log_function + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: List[str] + ) -> None: + """Trigger a backfill request to `dest` for the given `room_id` + + This will attempt to get more events from the remote. If the other side + has no new events to offer, this will return an empty list. + + As the events are received, we check their signatures, and also do some + sanity-checking on them. If any of the backfilled events are invalid, + this method throws a SynapseError. + + We might also raise an InvalidResponseError if the response from the remote + server is just bogus. + + TODO: make this more useful to distinguish failures of the remote + server from invalid events (there is probably no point in trying to + re-fetch invalid events from every other HS in the room.) + """ + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") + + events = await self.federation_client.backfill( + dest, room_id, limit=limit, extremities=extremities + ) + + if not events: + return + + # if there are any events in the wrong room, the remote server is buggy and + # should not be trusted. + for ev in events: + if ev.room_id != room_id: + raise InvalidResponseError( + f"Remote server {dest} returned event {ev.event_id} which is in " + f"room {ev.room_id}, when we were backfilling in {room_id}" + ) + + await self._process_pulled_events(dest, events, backfilled=True) + + async def _get_missing_events_for_pdu( + self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int + ) -> None: + """ + Args: + origin: Origin of the pdu. Will be called to get the missing events + pdu: received pdu + prevs: List of event ids which we are missing + min_depth: Minimum depth of events to return. + """ + + room_id = pdu.room_id + event_id = pdu.event_id + + seen = await self.store.have_events_in_timeline(prevs) + + if not prevs - seen: + return + + latest_list = await self.store.get_latest_event_ids_in_room(room_id) + + # We add the prev events that we have seen to the latest + # list to ensure the remote server doesn't give them to us + latest = set(latest_list) + latest |= seen + + logger.info( + "Requesting missing events between %s and %s", + shortstr(latest), + event_id, + ) + + # XXX: we set timeout to 10s to help workaround + # https://github.com/matrix-org/synapse/issues/1733. + # The reason is to avoid holding the linearizer lock + # whilst processing inbound /send transactions, causing + # FDs to stack up and block other inbound transactions + # which empirically can currently take up to 30 minutes. + # + # N.B. this explicitly disables retry attempts. + # + # N.B. this also increases our chances of falling back to + # fetching fresh state for the room if the missing event + # can't be found, which slightly reduces our security. + # it may also increase our DAG extremity count for the room, + # causing additional state resolution? See #1760. + # However, fetching state doesn't hold the linearizer lock + # apparently. + # + # see https://github.com/matrix-org/synapse/pull/1744 + # + # ---- + # + # Update richvdh 2018/09/18: There are a number of problems with timing this + # request out aggressively on the client side: + # + # - it plays badly with the server-side rate-limiter, which starts tarpitting you + # if you send too many requests at once, so you end up with the server carefully + # working through the backlog of your requests, which you have already timed + # out. + # + # - for this request in particular, we now (as of + # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the + # server can't produce a plausible-looking set of prev_events - so we becone + # much more likely to reject the event. + # + # - contrary to what it says above, we do *not* fall back to fetching fresh state + # for the room if get_missing_events times out. Rather, we give up processing + # the PDU whose prevs we are missing, which then makes it much more likely that + # we'll end up back here for the *next* PDU in the list, which exacerbates the + # problem. + # + # - the aggressive 10s timeout was introduced to deal with incoming federation + # requests taking 8 hours to process. It's not entirely clear why that was going + # on; certainly there were other issues causing traffic storms which are now + # resolved, and I think in any case we may be more sensible about our locking + # now. We're *certainly* more sensible about our logging. + # + # All that said: Let's try increasing the timeout to 60s and see what happens. + + try: + missing_events = await self.federation_client.get_missing_events( + origin, + room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + timeout=60000, + ) + except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: + # We failed to get the missing events, but since we need to handle + # the case of `get_missing_events` not returning the necessary + # events anyway, it is safe to simply log the error and continue. + logger.warning("Failed to get prev_events: %s", e) + return + + logger.info("Got %d prev_events", len(missing_events)) + await self._process_pulled_events(origin, missing_events, backfilled=False) + + async def _process_pulled_events( + self, origin: str, events: Iterable[EventBase], backfilled: bool + ) -> None: + """Process a batch of events we have pulled from a remote server + + Pulls in any events required to auth the events, persists the received events, + and notifies clients, if appropriate. + + Assumes the events have already had their signatures and hashes checked. + + Params: + origin: The server we received these events from + events: The received events. + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) + """ + + # We want to sort these by depth so we process them and + # tell clients about them in order. + sorted_events = sorted(events, key=lambda x: x.depth) + + for ev in sorted_events: + with nested_logging_context(ev.event_id): + await self._process_pulled_event(origin, ev, backfilled=backfilled) + + async def _process_pulled_event( + self, origin: str, event: EventBase, backfilled: bool + ) -> None: + """Process a single event that we have pulled from a remote server + + Pulls in any events required to auth the event, persists the received event, + and notifies clients, if appropriate. + + Assumes the event has already had its signatures and hashes checked. + + This is somewhat equivalent to on_receive_pdu, but applies somewhat different + logic in the case that we are missing prev_events (in particular, it just + requests the state at that point, rather than triggering a get_missing_events) - + so is appropriate when we have pulled the event from a remote server, rather + than having it pushed to us. + + Params: + origin: The server we received this event from + events: The received event + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) + """ + logger.info("Processing pulled event %s", event) + + # these should not be outliers. + assert not event.internal_metadata.is_outlier() + + event_id = event.event_id + + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + try: + self._sanity_check_event(event) + except SynapseError as err: + logger.warning("Event %s failed sanity check: %s", event_id, err) + return + + try: + state = await self._resolve_state_at_missing_prevs(origin, event) + await self._process_received_pdu( + origin, event, state=state, backfilled=backfilled + ) + except FederationError as e: + if e.code == 403: + logger.warning("Pulled event %s failed history check.", event_id) + else: + raise + + async def _resolve_state_at_missing_prevs( + self, dest: str, event: EventBase + ) -> Optional[Iterable[EventBase]]: + """Calculate the state at an event with missing prev_events. + + This is used when we have pulled a batch of events from a remote server, and + still don't have all the prev_events. + + If we already have all the prev_events for `event`, this method does nothing. + + Otherwise, the missing prevs become new backwards extremities, and we fall back + to asking the remote server for the state after each missing `prev_event`, + and resolving across them. + + That's ok provided we then resolve the state against other bits of the DAG + before using it - in other words, that the received event `event` is not going + to become the only forwards_extremity in the room (which will ensure that you + can't just take over a room by sending an event, withholding its prev_events, + and declaring yourself to be an admin in the subsequent state request). + + In other words: we should only call this method if `event` has been *pulled* + as part of a batch of missing prev events, or similar. + + Params: + dest: the remote server to ask for state at the missing prevs. Typically, + this will be the server we got `event` from. + event: an event to check for missing prevs. + + Returns: + if we already had all the prev events, `None`. Otherwise, returns a list of + the events in the state at `event`. + """ + room_id = event.room_id + event_id = event.event_id + + prevs = set(event.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + return None + + logger.info( + "Event %s is missing prev_events %s: calculating state for a " + "backwards extremity", + event_id, + shortstr(missing_prevs), + ) + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. + event_map = {event_id: event} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps: List[StateMap[str]] = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in missing_prevs: + logger.info("Requesting state after missing prev_event %s", p) + + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state = await self._get_state_after_missing_prev_event( + dest, room_id, p + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = await self.store.get_room_version_id(room_id) + state_map = await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) + + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "Error attempting to resolve state at missing prev_events", + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) + return state + + async def _get_state_after_missing_prev_event( + self, + destination: str, + room_id: str, + event_id: str, + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + + Returns: + A list of events in the state, including the event itself + """ + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + logger.debug( + "state_ids returned %i state events, %i auth events", + len(state_event_ids), + len(auth_event_ids), + ) + + # start by just trying to fetch the events from the store + desired_events = set(state_event_ids) + desired_events.add(event_id) + logger.debug("Fetching %i events from cache/store", len(desired_events)) + fetched_events = await self.store.get_events( + desired_events, allow_rejected=True + ) + + missing_desired_events = desired_events - fetched_events.keys() + logger.debug( + "We are missing %i events (got %i)", + len(missing_desired_events), + len(fetched_events), + ) + + # We probably won't need most of the auth events, so let's just check which + # we have for now, rather than thrashing the event cache with them all + # unnecessarily. + + # TODO: we probably won't actually need all of the auth events, since we + # already have a bunch of the state events. It would be nice if the + # federation api gave us a way of finding out which we actually need. + + missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events.difference_update( + await self.store.have_seen_events(room_id, missing_auth_events) + ) + logger.debug("We are also missing %i auth events", len(missing_auth_events)) + + missing_events = missing_desired_events | missing_auth_events + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + await self.store.get_events(missing_desired_events, allow_rejected=True) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + # if we couldn't get the prev event in question, that's a problem. + remote_event = fetched_events.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. + failed_to_fetch = desired_events - fetched_events.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events + ] + + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + return remote_state + + async def _process_received_pdu( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + backfilled: bool = False, + ) -> None: + """Called when we have a new pdu. We need to do auth checks and put it + through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event + + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) + """ + logger.debug("Processing event: %s", event) + + try: + context = await self.state_handler.compute_event_context( + event, old_state=state + ) + await self._auth_and_persist_event( + origin, event, context, state=state, backfilled=backfilled + ) + except AuthError as e: + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + + if backfilled: + return + + # For encrypted messages we check that we know about the sending device, + # if we don't then we mark the device cache for that user as stale. + if event.type == EventTypes.Encrypted: + device_id = event.content.get("device_id") + sender_key = event.content.get("sender_key") + + cached_devices = await self.store.get_cached_devices_for_user(event.sender) + + resync = False # Whether we should resync device lists. + + device = None + if device_id is not None: + device = cached_devices.get(device_id) + if device is None: + logger.info( + "Received event from remote device not in our cache: %s %s", + event.sender, + device_id, + ) + resync = True + + # We also check if the `sender_key` matches what we expect. + if sender_key is not None: + # Figure out what sender key we're expecting. If we know the + # device and recognize the algorithm then we can work out the + # exact key to expect. Otherwise check it matches any key we + # have for that device. + + current_keys: Container[str] = [] + + if device: + keys = device.get("keys", {}).get("keys", {}) + + if ( + event.content.get("algorithm") + == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 + ): + # For this algorithm we expect a curve25519 key. + key_name = "curve25519:%s" % (device_id,) + current_keys = [keys.get(key_name)] + else: + # We don't know understand the algorithm, so we just + # check it matches a key for the device. + current_keys = keys.values() + elif device_id: + # We don't have any keys for the device ID. + pass + else: + # The event didn't include a device ID, so we just look for + # keys across all devices. + current_keys = [ + key + for device in cached_devices.values() + for key in device.get("keys", {}).get("keys", {}).values() + ] + + # We now check that the sender key matches (one of) the expected + # keys. + if sender_key not in current_keys: + logger.info( + "Received event from remote device with unexpected sender key: %s %s: %s", + event.sender, + device_id or "", + sender_key, + ) + resync = True + + if resync: + run_as_background_process( + "resync_device_due_to_pdu", + self._resync_device, + event.sender, + ) + + await self._handle_marker_event(origin, event) + + async def _resync_device(self, sender: str) -> None: + """We have detected that the device list for the given user may be out + of sync, so we try and resync them. + """ + + try: + await self.store.mark_remote_user_device_cache_as_stale(sender) + + # Immediately attempt a resync in the background + if self.config.worker_app: + await self._user_device_resync(user_id=sender) + else: + await self._device_list_updater.user_device_resync(sender) + except Exception: + logger.exception("Failed to resync device for %s", sender) + + async def _handle_marker_event(self, origin: str, marker_event: EventBase): + """Handles backfilling the insertion event when we receive a marker + event that points to one. + + Args: + origin: Origin of the event. Will be called to get the insertion event + marker_event: The event to process + """ + + if marker_event.type != EventTypes.MSC2716_MARKER: + # Not a marker event + return + + if marker_event.rejected_reason is not None: + # Rejected event + return + + # Skip processing a marker event if the room version doesn't + # support it. + room_version = await self.store.get_room_version(marker_event.room_id) + if not room_version.msc2716_historical: + return + + logger.debug("_handle_marker_event: received %s", marker_event) + + insertion_event_id = marker_event.content.get( + EventContentFields.MSC2716_MARKER_INSERTION + ) + + if insertion_event_id is None: + # Nothing to retrieve then (invalid marker) + return + + logger.debug( + "_handle_marker_event: backfilling insertion event %s", insertion_event_id + ) + + await self._get_events_and_persist( + origin, + marker_event.room_id, + [insertion_event_id], + ) + + insertion_event = await self.store.get_event( + insertion_event_id, allow_none=True + ) + if insertion_event is None: + logger.warning( + "_handle_marker_event: server %s didn't return insertion event %s for marker %s", + origin, + insertion_event_id, + marker_event.event_id, + ) + return + + logger.debug( + "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", + insertion_event, + marker_event, + ) + + await self.store.insert_insertion_extremity( + insertion_event_id, marker_event.room_id + ) + + logger.debug( + "_handle_marker_event: insertion extremity added for %s from marker event %s", + insertion_event, + marker_event, + ) + + async def _get_events_and_persist( + self, destination: str, room_id: str, events: Iterable[str] + ) -> None: + """Fetch the given events from a server, and persist them as outliers. + + This function *does not* recursively get missing auth events of the + newly fetched events. Callers must include in the `events` argument + any missing events from the auth chain. + + Logs a warning if we can't find the given event. + """ + + room_version = await self.store.get_room_version(room_id) + + event_map: Dict[str, EventBase] = {} + + async def get_event(event_id: str): + with nested_logging_context(event_id): + try: + event = await self.federation_client.get_pdu( + [destination], + event_id, + room_version, + outlier=True, + ) + if event is None: + logger.warning( + "Server %s didn't return event %s", + destination, + event_id, + ) + return + + event_map[event.event_id] = event + + except Exception as e: + logger.warning( + "Error fetching missing state/auth event %s: %s %s", + event_id, + type(e), + e, + ) + + await concurrently_execute(get_event, events, 5) + + # Make a map of auth events for each event. We do this after fetching + # all the events as some of the events' auth events will be in the list + # of requested events. + + auth_events = [ + aid + for event in event_map.values() + for aid in event.auth_event_ids() + if aid not in event_map + ] + persisted_events = await self.store.get_events( + auth_events, + allow_rejected=True, + ) + + event_infos = [] + for event in event_map.values(): + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + else: + logger.info("Missing auth event %s", auth_event_id) + + event_infos.append(_NewEventInfo(event, auth)) + + if event_infos: + await self._auth_and_persist_events( + destination, + room_id, + event_infos, + ) + + async def _auth_and_persist_events( + self, + origin: str, + room_id: str, + event_infos: Collection[_NewEventInfo], + ) -> None: + """Creates the appropriate contexts and persists events. The events + should not depend on one another, e.g. this should be used to persist + a bunch of outliers, but not a chunk of individual events that depend + on each other for state calculations. + + Notifies about the events where appropriate. + """ + + if not event_infos: + return + + async def prep(ev_info: _NewEventInfo): + event = ev_info.event + with nested_logging_context(suffix=event.event_id): + res = await self.state_handler.compute_event_context(event) + res = await self._check_event_auth( + origin, + event, + res, + claimed_auth_event_map=ev_info.claimed_auth_event_map, + ) + return res + + contexts = await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(prep, ev_info) for ev_info in event_infos], + consumeErrors=True, + ) + ) + + await self.persist_events_and_notify( + room_id, + [ + (ev_info.event, context) + for ev_info, context in zip(event_infos, contexts) + ], + ) + + async def _auth_and_persist_event( + self, + origin: str, + event: EventBase, + context: EventContext, + state: Optional[Iterable[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, + backfilled: bool = False, + ) -> None: + """ + Process an event by performing auth checks and then persisting to the database. + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. + + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers. + + backfilled: True if the event was backfilled. + """ + context = await self._check_event_auth( + origin, + event, + context, + state=state, + claimed_auth_event_map=claimed_auth_event_map, + backfilled=backfilled, + ) + + await self._run_push_actions_and_persist_event(event, context, backfilled) + + async def _check_event_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + state: Optional[Iterable[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, + backfilled: bool = False, + ) -> EventContext: + """ + Checks whether an event should be rejected (for failing auth checks). + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. + + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers, or the state for a backwards + extremity. + + backfilled: True if the event was backfilled. + + Returns: + The updated context object. + """ + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + if claimed_auth_event_map: + # if we have a copy of the auth events from the event, use that as the + # basis for auth. + auth_events = claimed_auth_event_map + else: + # otherwise, we calculate what the auth events *should* be, and use that + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + + try: + ( + context, + auth_events_for_auth, + ) = await self._update_auth_events_and_context_for_auth( + origin, event, context, auth_events + ) + except Exception: + # We don't really mind if the above fails, so lets not fail + # processing if it does. However, it really shouldn't fail so + # let's still log as an exception since we'll still want to fix + # any bugs. + logger.exception( + "Failed to double check auth events for %s with remote. " + "Ignoring failure and continuing processing of event.", + event.event_id, + ) + auth_events_for_auth = auth_events + + try: + event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) + except AuthError as e: + logger.warning("Failed auth resolution for %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + + if not context.rejected: + await self._check_for_soft_fail(event, state, backfilled, origin=origin) + + if event.type == EventTypes.GuestAccess and not context.rejected: + await self.maybe_kick_guest_users(event) + + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event( + event, context + ) + + return context + + async def _check_for_soft_fail( + self, + event: EventBase, + state: Optional[Iterable[EventBase]], + backfilled: bool, + origin: str, + ) -> None: + """Checks if we should soft fail the event; if so, marks the event as + such. + + Args: + event + state: The state at the event if we don't have all the event's prev events + backfilled: Whether the event is from backfill + origin: The host the event originates from. + """ + # For new (non-backfilled and non-outlier) events we check if the event + # passes auth based on the current state. If it doesn't then we + # "soft-fail" the event. + if backfilled or event.internal_metadata.is_outlier(): + return + + extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids_list) + prev_event_ids = set(event.prev_event_ids()) + + if extrem_ids == prev_event_ids: + # If they're the same then the current state is the same as the + # state at the event, so no point rechecking auth for soft fail. + return + + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + # Calculate the "current state". + if state is not None: + # If we're explicitly given the state then we won't have all the + # prev events, and so we have a gap in the graph. In this case + # we want to be a little careful as we might have been down for + # a while and have an incorrect view of the current state, + # however we still want to do checks as gaps are easy to + # maliciously manufacture. + # + # So we use a "current state" that is actually a state + # resolution across the current forward extremities and the + # given state at the event. This should correctly handle cases + # like bans, especially with state res v2. + + state_sets_d = await self.state_store.get_state_groups( + event.room_id, extrem_ids + ) + state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) + state_sets.append(state) + current_states = await self.state_handler.resolve_events( + room_version, state_sets, event + ) + current_state_ids: StateMap[str] = { + k: e.event_id for k, e in current_states.items() + } + else: + current_state_ids = await self.state_handler.get_current_state_ids( + event.room_id, latest_event_ids=extrem_ids + ) + + logger.debug( + "Doing soft-fail check for %s: state %s", + event.event_id, + current_state_ids, + ) + + # Now check if event pass auth against said current state + auth_types = auth_types_for_event(room_version_obj, event) + current_state_ids_list = [ + e for k, e in current_state_ids.items() if k in auth_types + ] + + auth_events_map = await self.store.get_events(current_state_ids_list) + current_auth_events = { + (e.type, e.state_key): e for e in auth_events_map.values() + } + + try: + event_auth.check(room_version_obj, event, auth_events=current_auth_events) + except AuthError as e: + logger.warning( + "Soft-failing %r (from %s) because %s", + event, + e, + origin, + extra={ + "room_id": event.room_id, + "mxid": event.sender, + "hs": origin, + }, + ) + soft_failed_event_counter.inc() + event.internal_metadata.soft_failed = True + + async def _update_auth_events_and_context_for_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + input_auth_events: StateMap[EventBase], + ) -> Tuple[EventContext, StateMap[EventBase]]: + """Helper for _check_event_auth. See there for docs. + + Checks whether a given event has the expected auth events. If it + doesn't then we talk to the remote server to compare state to see if + we can come to a consensus (e.g. if one server missed some valid + state). + + This attempts to resolve any potential divergence of state between + servers, but is not essential and so failures should not block further + processing of the event. + + Args: + origin: + event: + context: + + input_auth_events: + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + + Returns: + updated context, updated auth event map + """ + # take a copy of input_auth_events before we modify it. + auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + + event_auth_events = set(event.auth_event_ids()) + + # missing_auth is the set of the event's auth_events which we don't yet have + # in auth_events. + missing_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + # if we have missing events, we need to fetch those events from somewhere. + # + # we start by checking if they are in the store, and then try calling /event_auth/. + if missing_auth: + have_events = await self.store.have_seen_events(event.room_id, missing_auth) + logger.debug("Events %s are in the store", have_events) + missing_auth.difference_update(have_events) + + if missing_auth: + # If we don't have all the auth events, we need to get them. + logger.info("auth_events contains unknown events: %s", missing_auth) + try: + try: + remote_auth_chain = await self.federation_client.get_event_auth( + origin, event.room_id, event.event_id + ) + except RequestSendFailed as e1: + # The other side isn't around or doesn't implement the + # endpoint, so lets just bail out. + logger.info("Failed to get event auth from remote: %s", e1) + return context, auth_events + + seen_remotes = await self.store.have_seen_events( + event.room_id, [e.event_id for e in remote_auth_chain] + ) + + for e in remote_auth_chain: + if e.event_id in seen_remotes: + continue + + if e.event_id == event.event_id: + continue + + try: + auth_ids = e.auth_event_ids() + auth = { + (e.type, e.state_key): e + for e in remote_auth_chain + if e.event_id in auth_ids or e.type == EventTypes.Create + } + e.internal_metadata.outlier = True + + logger.debug( + "_check_event_auth %s missing_auth: %s", + event.event_id, + e.event_id, + ) + missing_auth_event_context = ( + await self.state_handler.compute_event_context(e) + ) + await self._auth_and_persist_event( + origin, + e, + missing_auth_event_context, + claimed_auth_event_map=auth, + ) + + if e.event_id in event_auth_events: + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + except Exception: + logger.exception("Failed to get auth chain") + + if event.internal_metadata.is_outlier(): + # XXX: given that, for an outlier, we'll be working with the + # event's *claimed* auth events rather than those we calculated: + # (a) is there any point in this test, since different_auth below will + # obviously be empty + # (b) alternatively, why don't we do it earlier? + logger.info("Skipping auth_event fetch for outlier") + return context, auth_events + + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + if not different_auth: + return context, auth_events + + logger.info( + "auth_events refers to events which are not in our calculated auth " + "chain: %s", + different_auth, + ) + + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? + different_events = await self.store.get_events_as_list(different_auth) + + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) + + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context, auth_events + + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. + + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() + + room_version = await self.store.get_room_version_id(event.room_id) + new_state = await self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) + + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) + + auth_events.update(new_state) + + context = await self._update_context_for_auth_events( + event, context, auth_events + ) + + return context, auth_events + + async def _update_context_for_auth_events( + self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] + ) -> EventContext: + """Update the state_ids in an event context after auth event resolution, + storing the changes as a new state group. + + Args: + event: The event we're handling the context for + + context: initial event context + + auth_events: Events to update in the event context. + + Returns: + new event context + """ + # exclude the state key of the new event from the current_state in the context. + if event.is_state(): + event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) + else: + event_key = None + state_updates = { + k: a.event_id for k, a in auth_events.items() if k != event_key + } + + current_state_ids = await context.get_current_state_ids() + current_state_ids = dict(current_state_ids) # type: ignore + + current_state_ids.update(state_updates) + + prev_state_ids = await context.get_prev_state_ids() + prev_state_ids = dict(prev_state_ids) + + prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) + + # create a new state group as a delta from the existing one. + prev_group = context.state_group + state_group = await self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=prev_group, + delta_ids=state_updates, + current_state_ids=current_state_ids, + ) + + return EventContext.with_state( + state_group=state_group, + state_group_before_event=context.state_group_before_event, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=prev_group, + delta_ids=state_updates, + ) + + async def _run_push_actions_and_persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ): + """Run the push actions for a received event, and persist it. + + Args: + event: The event itself. + context: The event context. + backfilled: True if the event was backfilled. + """ + try: + if ( + not event.internal_metadata.is_outlier() + and not backfilled + and not context.rejected + and (await self.store.get_min_depth(event.room_id)) <= event.depth + ): + await self.action_generator.handle_push_actions_for_event( + event, context + ) + + await self.persist_events_and_notify( + event.room_id, [(event, context)], backfilled=backfilled + ) + except Exception: + run_in_background( + self.store.remove_push_actions_from_staging, event.event_id + ) + raise + + async def persist_events_and_notify( + self, + room_id: str, + event_and_contexts: Sequence[Tuple[EventBase, EventContext]], + backfilled: bool = False, + ) -> int: + """Persists events and tells the notifier/pushers about them, if + necessary. + + Args: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. + backfilled: Whether these events are a result of + backfilling or not + + Returns: + The stream ID after which all events have been persisted. + """ + if not event_and_contexts: + return self.store.get_current_events_token() + + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: + # Limit the number of events sent over replication. We choose 200 + # here as that is what we default to in `max_request_body_size(..)` + for batch in batch_iter(event_and_contexts, 200): + result = await self._send_events( + instance_name=instance, + store=self.store, + room_id=room_id, + event_and_contexts=batch, + backfilled=backfilled, + ) + return result["max_stream_id"] + else: + assert self.storage.persistence + + # Note that this returns the events that were persisted, which may not be + # the same as were passed in if some were deduplicated due to transaction IDs. + events, max_stream_token = await self.storage.persistence.persist_events( + event_and_contexts, backfilled=backfilled + ) + + if self._ephemeral_messages_enabled: + for event in events: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + + if not backfilled: # Never notify for backfilled events + for event in events: + await self._notify_persisted_event(event, max_stream_token) + + return max_stream_token.stream + + async def _notify_persisted_event( + self, event: EventBase, max_stream_token: RoomStreamToken + ) -> None: + """Checks to see if notifier/pushers should be notified about the + event or not. + + Args: + event: + max_stream_token: The max_stream_id returned by persist_events + """ + + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + + # We notify for memberships if its an invite for one of our + # users + if event.internal_metadata.is_outlier(): + if event.membership != Membership.INVITE: + if not self.is_mine_id(target_user_id): + return + + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) + elif event.internal_metadata.is_outlier(): + return + + # the event has been persisted so it should have a stream ordering. + assert event.internal_metadata.stream_ordering + + event_pos = PersistedEventPosition( + self._instance_name, event.internal_metadata.stream_ordering + ) + self.notifier.on_new_room_event( + event, event_pos, max_stream_token, extra_users=extra_users + ) + + def _sanity_check_event(self, ev: EventBase) -> None: + """ + Do some early sanity checks of a received event + + In particular, checks it doesn't have an excessive number of + prev_events or auth_events, which could cause a huge state resolution + or cascade of event fetches. + + Args: + ev: event to be checked + + Raises: + SynapseError if the event does not pass muster + """ + if len(ev.prev_event_ids()) > 20: + logger.warning( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") + + if len(ev.auth_event_ids()) > 10: + logger.warning( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + + async def get_min_depth_for_context(self, context: str) -> int: + return await self.store.get_min_depth(context) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 79cadb7b57..a0b3145f4e 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() self.storage = hs.get_storage() self.clock = hs.get_clock() - self.federation_handler = hs.get_federation_handler() + self.federation_event_handler = hs.get_federation_event_handler() @staticmethod async def _serialize_payload(store, room_id, event_and_contexts, backfilled): @@ -127,7 +127,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - max_stream_id = await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_event_handler.persist_events_and_notify( room_id, event_and_contexts, backfilled ) diff --git a/synapse/server.py b/synapse/server.py index de6517663e..5adeeff61a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -76,6 +76,7 @@ from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler +from synapse.handlers.federation_event import FederationEventHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler @@ -546,6 +547,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_federation_handler(self) -> FederationHandler: return FederationHandler(self) + @cache_in_self + def get_federation_event_handler(self) -> FederationEventHandler: + return FederationEventHandler(self) + @cache_in_self def get_identity_handler(self) -> IdentityHandler: return IdentityHandler(self) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 383214ab50..663960ff53 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -208,7 +208,7 @@ class FederationKnockingTestCase( async def _check_event_auth(origin, event, context, *args, **kwargs): return context - homeserver.get_federation_handler()._check_event_auth = _check_event_auth + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth return super().prepare(reactor, clock, homeserver) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c72a8972a3..6c67a16de9 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -130,7 +130,9 @@ class FederationTestCase(unittest.HomeserverTestCase): ) with LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -182,7 +184,9 @@ class FederationTestCase(unittest.HomeserverTestCase): ) with LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -311,7 +315,9 @@ class FederationTestCase(unittest.HomeserverTestCase): with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver d = run_in_background( - self.handler.on_receive_pdu, OTHER_SERVER, message_event + self.hs.get_federation_event_handler().on_receive_pdu, + OTHER_SERVER, + message_event, ) self.get_success(d) @@ -382,7 +388,9 @@ class FederationTestCase(unittest.HomeserverTestCase): join_event.signatures[other_server] = {"x": "y"} with LoggingContext("send_join"): d = run_in_background( - self.handler.on_send_membership_event, other_server, join_event + self.hs.get_federation_event_handler().on_send_membership_event, + other_server, + join_event, ) self.get_success(d) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 0a52bc8b72..671dc7d083 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -885,7 +885,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.federation_sender = hs.get_federation_sender() self.event_builder_factory = hs.get_event_builder_factory() - self.federation_handler = hs.get_federation_handler() + self.federation_event_handler = hs.get_federation_event_handler() self.presence_handler = hs.get_presence_handler() # self.event_builder_for_2 = EventBuilderFactory(hs) @@ -1026,7 +1026,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) ) - self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) + self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) # Check that it was successfully persisted. self.get_success(self.store.get_event(event.event_id)) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index af5dfca752..92a5b53e11 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -205,7 +205,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) store = self.hs.get_datastore() - federation = self.hs.get_federation_handler() + federation = self.hs.get_federation_event_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) room_version = self.get_success(store.get_room_version(room)) diff --git a/tests/test_federation.py b/tests/test_federation.py index 348fcb72a7..61c9d7c2ef 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,7 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + federation_event_handler = self.homeserver.get_federation_event_handler() + federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( context ) self.client = self.homeserver.get_federation_client() @@ -85,7 +86,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Send the join, it should return None (which is not an error) self.assertEqual( - self.get_success(self.handler.on_receive_pdu("test.serv", join_event)), + self.get_success( + federation_event_handler.on_receive_pdu("test.serv", join_event) + ), None, ) @@ -129,9 +132,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) + federation_event_handler = self.homeserver.get_federation_event_handler() with LoggingContext("test-context"): failure = self.get_failure( - self.handler.on_receive_pdu("test.serv", lying_event), + federation_event_handler.on_receive_pdu("test.serv", lying_event), FederationError, ) -- cgit 1.5.1 From e62cdbef1a499f428e48f98167b2b709d16c671d Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 27 Aug 2021 11:16:40 +0200 Subject: Improve ServerNoticeServlet to avoid duplicate requests (#10679) Fixes: #9544 --- changelog.d/10679.bugfix | 1 + synapse/rest/admin/__init__.py | 5 +- synapse/rest/admin/server_notice_servlet.py | 19 +- synapse/server_notices/server_notices_manager.py | 17 +- tests/rest/admin/test_server_notice.py | 450 +++++++++++++++++++++++ 5 files changed, 475 insertions(+), 17 deletions(-) create mode 100644 changelog.d/10679.bugfix create mode 100644 tests/rest/admin/test_server_notice.py (limited to 'tests') diff --git a/changelog.d/10679.bugfix b/changelog.d/10679.bugfix new file mode 100644 index 0000000000..5c4061f6d5 --- /dev/null +++ b/changelog.d/10679.bugfix @@ -0,0 +1 @@ +Improve ServerNoticeServlet to avoid duplicate requests and add unit tests. \ No newline at end of file diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6e1c8736e1..b2514d9d0d 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -223,7 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) @@ -247,6 +246,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + # Some servlets only get registered for the main process. + if hs.config.worker_app is None: + SendServerNoticeServlet(hs).register(http_server) + def register_servlets_for_client_rest_resource( hs: "HomeServer", http_server: HttpServer diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b5e4c474ef..42201afc86 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() + self.server_notices_manager = hs.get_server_notices_manager() + self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) def register(self, json_resource: HttpServer): @@ -79,19 +81,22 @@ class SendServerNoticeServlet(RestServlet): # We grab the server notices manager here as its initialisation has a check for worker processes, # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). - if not self.hs.get_server_notices_manager().is_enabled(): + if not self.server_notices_manager.is_enabled(): raise SynapseError(400, "Server notices are not enabled on this server") - user_id = body["user_id"] - UserID.from_string(user_id) - if not self.hs.is_mine_id(user_id): + target_user = UserID.from_string(body["user_id"]) + if not self.hs.is_mine(target_user): raise SynapseError(400, "Server notices can only be sent to local users") - event = await self.hs.get_server_notices_manager().send_notice( - user_id=body["user_id"], + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + event = await self.server_notices_manager.send_notice( + user_id=target_user.to_string(), type=event_type, state_key=state_key, event_content=body["content"], + txn_id=txn_id, ) return 200, {"event_id": event.event_id} diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index f19075b760..d87a538917 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -12,26 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.events import EventBase from synapse.types import UserID, create_requester from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ - + def __init__(self, hs: "HomeServer"): self._store = hs.get_datastore() self._config = hs.config self._account_data_handler = hs.get_account_data_handler() @@ -58,6 +55,7 @@ class ServerNoticesManager: event_content: dict, type: str = EventTypes.Message, state_key: Optional[str] = None, + txn_id: Optional[str] = None, ) -> EventBase: """Send a notice to the given user @@ -68,6 +66,7 @@ class ServerNoticesManager: event_content: content of event to send type: type of event is_state_event: Is the event a state event + txn_id: The transaction ID. """ room_id = await self.get_or_create_notice_room_for_user(user_id) await self.maybe_invite_user_to_room(user_id, room_id) @@ -90,7 +89,7 @@ class ServerNoticesManager: event_dict["state_key"] = state_key event, _ = await self._event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, ratelimit=False + requester, event_dict, ratelimit=False, txn_id=txn_id ) return event diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py new file mode 100644 index 0000000000..fbceba3254 --- /dev/null +++ b/tests/rest/admin/test_server_notice.py @@ -0,0 +1,450 @@ +# Copyright 2021 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login, room, sync +from synapse.storage.roommember import RoomsForUser +from synapse.types import JsonDict + +from tests import unittest +from tests.unittest import override_config + + +class ServerNoticeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() + self.server_notices_manager = self.hs.get_server_notices_manager() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/send_server_notice" + + def test_no_auth(self): + """Try to send a server notice without authentication.""" + channel = self.make_request("POST", self.url) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """If the user is not a server admin, an error is returned.""" + channel = self.make_request( + "POST", + self.url, + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_does_not_exist(self): + """Tests that a lookup for a user that does not exist returns a 404""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": "@unknown_person:test", "content": ""}, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": "@unknown_person:unknown_domain", + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Server notices can only be sent to local users", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_invalid_parameter(self): + """If parameters are invalid, an error is returned.""" + + # no content, no user + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) + + # no content + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # no body + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user, "content": ""}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'body' not in content", channel.json_body["error"]) + + # no msgtype + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user, "content": {"body": ""}}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'msgtype' not in content", channel.json_body["error"]) + + def test_server_notice_disabled(self): + """Tests that server returns error if server notice is disabled""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Server notices are not enabled on this server", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice(self): + """ + Tests that sending two server notices is successfully, + the server uses the same room and do not send messages twice. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # invalidate cache of server notices room_ids + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has no new invites or memberships + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + + self.assertEqual(len(messages), 2) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + self.assertEqual(messages[1]["content"]["body"], "test msg two") + self.assertEqual(messages[1]["sender"], "@notices:test") + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_leave_room(self): + """ + Tests that sending a server notices is successfully. + The user leaves the room and the second message appears + in a new room. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # user leaves the romm + self.helper.leave( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id the user gets the message + # in old room + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # room has the same id + self.assertNotEqual(first_room_id, second_room_id) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_delete_room(self): + """ + Tests that the user get server notice in a new room + after the first server notice room was deleted. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # shut down and purge room + self.get_success( + self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user) + ) + self.get_success(self.pagination_handler.purge_room(first_room_id)) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(first_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id it gives an error + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get message + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # second room has new ID + self.assertNotEqual(first_room_id, second_room_id) + + def _check_invite_and_join_status( + self, user_id: str, expected_invites: int, expected_memberships: int + ) -> RoomsForUser: + """Check invite and room membership status of a user. + + Args + user_id: user to check + expected_invites: number of expected invites of this user + expected_memberships: number of expected room memberships of this user + Returns + room_ids from the rooms that the user is invited + """ + + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(user_id) + ) + self.assertEqual(expected_invites, len(invited_rooms)) + + room_ids = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(expected_memberships, len(room_ids)) + + return invited_rooms + + def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]: + """ + Do a sync and get messages of a room. + + Args + room_id: room that contains the messages + token: access token of user + + Returns + list of messages contained in the room + """ + channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=token + ) + self.assertEqual(channel.code, 200) + + # Get the messages + room = channel.json_body["rooms"]["join"][room_id] + messages = [ + x for x in room["timeline"]["events"] if x["type"] == "m.room.message" + ] + return messages -- cgit 1.5.1 From 8c26f16c76b475e0ace7b58920d90368b180454c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 31 Aug 2021 12:56:22 +0100 Subject: Fix up unit tests (#10723) These were broken in an incorrect merge of GHSA-jj53-8fmw-f2w2 (cb35df9) --- changelog.d/10723.bugfix | 1 + tests/rest/client/v2_alpha/test_groups.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10723.bugfix (limited to 'tests') diff --git a/changelog.d/10723.bugfix b/changelog.d/10723.bugfix new file mode 100644 index 0000000000..e6ffdc9512 --- /dev/null +++ b/changelog.d/10723.bugfix @@ -0,0 +1 @@ +Fix unauthorised exposure of room metadata to communities. diff --git a/tests/rest/client/v2_alpha/test_groups.py b/tests/rest/client/v2_alpha/test_groups.py index bfa9336baa..ad0425ae65 100644 --- a/tests/rest/client/v2_alpha/test_groups.py +++ b/tests/rest/client/v2_alpha/test_groups.py @@ -1,5 +1,18 @@ -from synapse.rest.client.v1 import room -from synapse.rest.client.v2_alpha import groups +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client import groups, room from tests import unittest from tests.unittest import override_config -- cgit 1.5.1