From fab495a9e1442d99e922367f65f41de5eaa488eb Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Fri, 21 Oct 2022 08:49:47 +0000 Subject: Fix event size checks (#13710) --- synapse/event_auth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/event_auth.py b/synapse/event_auth.py index bab31e33c5..5036604036 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -342,15 +342,15 @@ def check_state_dependent_auth_rules( def _check_size_limits(event: "EventBase") -> None: - if len(event.user_id) > 255: + if len(event.user_id.encode("utf-8")) > 255: raise EventSizeError("'user_id' too large") - if len(event.room_id) > 255: + if len(event.room_id.encode("utf-8")) > 255: raise EventSizeError("'room_id' too large") - if event.is_state() and len(event.state_key) > 255: + if event.is_state() and len(event.state_key.encode("utf-8")) > 255: raise EventSizeError("'state_key' too large") - if len(event.type) > 255: + if len(event.type.encode("utf-8")) > 255: raise EventSizeError("'type' too large") - if len(event.event_id) > 255: + if len(event.event_id.encode("utf-8")) > 255: raise EventSizeError("'event_id' too large") if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: raise EventSizeError("event too large") -- cgit 1.4.1 From 1433b5d5b64c3a6624e6e4ff4fef22127c49df86 Mon Sep 17 00:00:00 2001 From: Tadeusz Sośnierz Date: Fri, 21 Oct 2022 14:52:44 +0200 Subject: Show erasure status when listing users in the Admin API (#14205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Show erasure status when listing users in the Admin API * Use USING when joining erased_users * Add changelog entry * Revert "Use USING when joining erased_users" This reverts commit 30bd2bf106415caadcfdbdd1b234ef2b106cc394. * Make the erased check work on postgres * Add a testcase for showing erased user status * Appease the style linter * Explicitly convert `erased` to bool to make SQLite consistent with Postgres This also adds us an easy way in to fix the other accidentally integered columns. * Move erasure status test to UsersListTestCase * Include user erased status when fetching user info via the admin API * Document the erase status in user_admin_api * Appease the linter and mypy * Signpost comments in tests Co-authored-by: Tadeusz Sośnierz Co-authored-by: David Robertson --- changelog.d/14205.feature | 1 + docs/admin_api/user_admin_api.md | 4 ++++ synapse/handlers/admin.py | 1 + synapse/storage/databases/main/__init__.py | 13 +++++++++-- tests/rest/admin/test_user.py | 35 +++++++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14205.feature (limited to 'synapse') diff --git a/changelog.d/14205.feature b/changelog.d/14205.feature new file mode 100644 index 0000000000..6692063352 --- /dev/null +++ b/changelog.d/14205.feature @@ -0,0 +1 @@ +Show erasure status when listing users in the Admin API. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 3625c7b6c5..c95d6c9b05 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -37,6 +37,7 @@ It returns a JSON body like the following: "is_guest": 0, "admin": 0, "deactivated": 0, + "erased": false, "shadow_banned": 0, "creation_ts": 1560432506, "appservice_id": null, @@ -167,6 +168,7 @@ A response body like the following is returned: "admin": 0, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": null, @@ -177,6 +179,7 @@ A response body like the following is returned: "admin": 1, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": "", @@ -247,6 +250,7 @@ The following fields are returned in the JSON response body: - `user_type` - string - Type of the user. Normal users are type `None`. This allows user type specific behaviour. There are also types `support` and `bot`. - `deactivated` - bool - Status if that user has been marked as deactivated. + - `erased` - bool - Status if that user has been marked as erased. - `shadow_banned` - bool - Status if that user has been marked as shadow banned. - `displayname` - string - The user's display name if they have set one. - `avatar_url` - string - The user's avatar URL if they have set one. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f2989cc4a2..5bf8e86387 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -100,6 +100,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids + user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) return user_info_dict diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index a62b4abd4e..cfaedf5e0c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -201,7 +201,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: str = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.NAME.value, direction: str = "f", approved: bool = True, ) -> Tuple[List[JsonDict], int]: @@ -261,6 +261,7 @@ class DataStore( sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + LEFT JOIN erased_users AS eu ON u.name = eu.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -269,7 +270,8 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts, approved + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, + eu.user_id is not null as erased {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? @@ -277,6 +279,13 @@ class DataStore( args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) + + # some of those boolean values are returned as integers when we're on SQLite + columns_to_boolify = ["erased"] + for user in users: + for column in columns_to_boolify: + user[column] = bool(user[column]) + return users, count return await self.db_pool.runInteraction( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4c1ce33463..63410ffdf1 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersions from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -924,6 +924,36 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def test_erasure_status(self) -> None: + # Create a new user. + user_id = self.register_user("eraseme", "eraseme") + + # They should appear in the list users API, marked as not erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], False) + + # Deactivate that user, requesting erasure. + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + user_id, erase_data=True, requester=create_requester(user_id) + ) + ) + + # Repeat the list users query. They should now be marked as erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], True) + def _order_test( self, expected_user_list: List[str], @@ -1195,6 +1225,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) + self.assertFalse(channel.json_body["erased"]) # Deactivate and erase user channel = self.make_request( @@ -1219,6 +1250,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) + self.assertTrue(channel.json_body["erased"]) self._is_erased("@user:test", True) @@ -2757,6 +2789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("avatar_url", content) self.assertIn("admin", content) self.assertIn("deactivated", content) + self.assertIn("erased", content) self.assertIn("shadow_banned", content) self.assertIn("creation_ts", content) self.assertIn("appservice_id", content) -- cgit 1.4.1 From 4dd7aa371b6bc746fa4b0a9af220b2013b17a45d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Oct 2022 09:11:19 -0400 Subject: Properly update the threads table when thread events are redacted. (#14248) When the last event in a thread is redacted we need to update the threads table: * Find the new latest event in the thread and store it into the table; or * Remove the thread from the table if it is no longer a thread (i.e. all events in the thread were redacted). --- changelog.d/14248.bugfix | 1 + synapse/storage/databases/main/events.py | 61 ++++++++++++++--- tests/rest/client/test_relations.py | 110 +++++++++++++++++++++---------- 3 files changed, 129 insertions(+), 43 deletions(-) create mode 100644 changelog.d/14248.bugfix (limited to 'synapse') diff --git a/changelog.d/14248.bugfix b/changelog.d/14248.bugfix new file mode 100644 index 0000000000..203c52c16b --- /dev/null +++ b/changelog.d/14248.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where the information returned from the `/threads` API could be stale when threaded events are redacted. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6698cbf664..00880bb37d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2028,25 +2028,37 @@ class PersistEventsStore: redacted_event_id: The event that was redacted. """ - # Fetch the current relation of the event being redacted. - redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + # Fetch the relation of the event being redacted. + row = self.db_pool.simple_select_one_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id}, - retcol="relates_to_id", + retcols=("relates_to_id", "relation_type"), allow_none=True, ) + # Nothing to do if no relation is found. + if row is None: + return + + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + # Any relation information for the related event must be cleared. - if redacted_relates_to is not None: - self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + if rel_type == RelationTypes.ANNOTATION: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) ) + if rel_type == RelationTypes.THREAD: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_summary, (redacted_relates_to,) ) @@ -2057,9 +2069,38 @@ class PersistEventsStore: txn, self.store.get_threads, (room_id,) ) - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). + latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ddf315b894..e3d801f7a8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1523,6 +1523,26 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _get_threads(self) -> List[Tuple[str, str]]: + """Request the threads in the room and returns a list of thread ID and latest event ID.""" + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + threads = channel.json_body["chunk"] + return [ + ( + t["event_id"], + t["unsigned"]["m.relations"][RelationTypes.THREAD]["latest_event"][ + "event_id" + ], + ) + for t in threads + ] + def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1567,58 +1587,82 @@ class RelationRedactionTestCase(BaseRelationsTestCase): The redacted event should not be included in bundled aggregations or the response to relations. """ - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 1", "msgtype": "m.text"}, - ) - unredacted_event_id = channel.json_body["event_id"] + # Create a thread with a few events in it. + thread_replies = [] + for i in range(3): + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": f"reply {i}", "msgtype": "m.text"}, + ) + thread_replies.append(channel.json_body["event_id"]) - # Note that the *last* event in the thread is redacted, as that gets - # included in the bundled aggregation. - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 2", "msgtype": "m.text"}, + ################################################## + # Check the test data is configured as expected. # + ################################################## + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + relations = self._get_bundled_aggregations() + self.assertDictContainsSubset( + {"count": 3, "current_user_participated": True}, + relations[RelationTypes.THREAD], + ) + # The latest event is the last sent event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + thread_replies[-1], ) - to_redact_event_id = channel.json_body["event_id"] - # Both relations exist. - event_ids = self._get_related_events() + # There should be one thread, the latest event is the event that will be redacted. + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + ########################## + # Redact the last event. # + ########################## + self._redact(thread_replies.pop()) + + # The thread should still exist, but the latest event should be updated. + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 2, - "current_user_participated": True, - }, + {"count": 2, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event returned is the event that will be redacted. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - to_redact_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) - # Redact one of the reactions. - self._redact(to_redact_event_id) + ########################################### + # Redact the *first* event in the thread. # + ########################################### + self._redact(thread_replies.pop(0)) - # The unredacted relation should still exist. - event_ids = self._get_related_events() + # Nothing should have changed (except the thread count). + self.assertEquals(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 1, - "current_user_participated": True, - }, + {"count": 1, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event is now the unredacted event. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - unredacted_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + #################################### + # Redact the last remaining event. # + #################################### + self._redact(thread_replies.pop(0)) + self.assertEquals(thread_replies, []) + + # The event should no longer be considered a thread. + self.assertEquals(self._get_related_events(), []) + self.assertEquals(self._get_bundled_aggregations(), {}) + self.assertEqual(self._get_threads(), []) def test_redact_parent_edit(self) -> None: """Test that edits of an event are redacted when the original event -- cgit 1.4.1 From d24346f53055eae7fb8e9038ef35fa843790742b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 21 Oct 2022 16:03:44 +0100 Subject: Fix logging error on SIGHUP (#14258) --- changelog.d/14258.bugfix | 2 ++ synapse/app/_base.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14258.bugfix (limited to 'synapse') diff --git a/changelog.d/14258.bugfix b/changelog.d/14258.bugfix new file mode 100644 index 0000000000..de97945844 --- /dev/null +++ b/changelog.d/14258.bugfix @@ -0,0 +1,2 @@ +Fix a bug introduced in Synapse 1.60.0 which caused an error to be logged when Synapse received a SIGHUP signal, and debug logging was enabled. + diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 000912e86e..a683ebf4cb 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -558,7 +558,7 @@ def reload_cache_config(config: HomeServerConfig) -> None: logger.warning(f) else: logger.debug( - "New cache config. Was:\n %s\nNow:\n", + "New cache config. Was:\n %s\nNow:\n %s", previous_cache_config.__dict__, config.caches.__dict__, ) -- cgit 1.4.1 From 1d45ad8b2ab1c41dd489ccd581d027077bc917e5 Mon Sep 17 00:00:00 2001 From: Germain Date: Fri, 21 Oct 2022 18:44:00 +0100 Subject: Improve aesthetics and reusability of HTML templates. (#13652) Use a base template to create a cohesive feel across the HTML templates provided by Synapse. Adds basic styling to the base template for a more user-friendly look and feel. --- changelog.d/13652.feature | 1 + synapse/res/templates/_base.html | 29 ++ .../res/templates/account_previously_renewed.html | 18 +- synapse/res/templates/account_renewed.html | 18 +- synapse/res/templates/add_threepid.html | 22 +- synapse/res/templates/add_threepid_failure.html | 20 +- synapse/res/templates/add_threepid_success.html | 18 +- synapse/res/templates/auth_success.html | 28 +- synapse/res/templates/invalid_token.html | 17 +- synapse/res/templates/notice_expiry.html | 93 +++--- synapse/res/templates/notif_mail.html | 116 ++++--- synapse/res/templates/password_reset.html | 19 +- .../res/templates/password_reset_confirmation.html | 14 +- synapse/res/templates/password_reset_failure.html | 14 +- synapse/res/templates/password_reset_success.html | 12 +- synapse/res/templates/recaptcha.html | 19 +- synapse/res/templates/registration.html | 21 +- synapse/res/templates/registration_failure.html | 12 +- synapse/res/templates/registration_success.html | 13 +- synapse/res/templates/registration_token.html | 16 +- synapse/res/templates/sso_account_deactivated.html | 49 ++- .../res/templates/sso_auth_account_details.html | 372 ++++++++++----------- synapse/res/templates/sso_auth_bad_user.html | 52 ++- synapse/res/templates/sso_auth_confirm.html | 56 ++-- synapse/res/templates/sso_auth_success.html | 54 ++- synapse/res/templates/sso_error.html | 34 +- synapse/res/templates/sso_login_idp_picker.html | 114 +++---- synapse/res/templates/sso_new_user_consent.html | 60 ++-- synapse/res/templates/sso_redirect_confirm.html | 75 ++--- synapse/res/templates/style.css | 29 ++ synapse/res/templates/terms.html | 16 +- 31 files changed, 691 insertions(+), 740 deletions(-) create mode 100644 changelog.d/13652.feature create mode 100644 synapse/res/templates/_base.html create mode 100644 synapse/res/templates/style.css (limited to 'synapse') diff --git a/changelog.d/13652.feature b/changelog.d/13652.feature new file mode 100644 index 0000000000..bc7f2926dc --- /dev/null +++ b/changelog.d/13652.feature @@ -0,0 +1 @@ +Improve aesthetics of HTML templates. Note that these changes do not retroactively apply to templates which have been [customised](https://matrix-org.github.io/synapse/latest/templates.html#templates) by server admins. \ No newline at end of file diff --git a/synapse/res/templates/_base.html b/synapse/res/templates/_base.html new file mode 100644 index 0000000000..46439fce6a --- /dev/null +++ b/synapse/res/templates/_base.html @@ -0,0 +1,29 @@ + + + + + + + {% block title %}{% endblock %} + + {% block header %}{% endblock %} + + +
+ {% if app_name == "Riot" %} + [Riot] + {% elif app_name == "Vector" %} + [Vector] + {% elif app_name == "Element" %} + [Element] + {% else %} + [matrix] + {% endif %} +
+ +{% block body %}{% endblock %} + + + diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html index bd4f7cea97..91582a8af0 100644 --- a/synapse/res/templates/account_previously_renewed.html +++ b/synapse/res/templates/account_previously_renewed.html @@ -1,12 +1,6 @@ - - - - - - - Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - - Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - \ No newline at end of file +{% extends "_base.html" %} +{% block title %}Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.{% endblock %} + +{% block body %} +

Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.

+{% endblock %} diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html index 57b319f375..18a57833f1 100644 --- a/synapse/res/templates/account_renewed.html +++ b/synapse/res/templates/account_renewed.html @@ -1,12 +1,6 @@ - - - - - - - Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - - Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. - - \ No newline at end of file +{% extends "_base.html" %} +{% block title %}Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.{% endblock %} + +{% block body %} +

Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.

+{% endblock %} diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html index 71f2215b7a..33c883936a 100644 --- a/synapse/res/templates/add_threepid.html +++ b/synapse/res/templates/add_threepid.html @@ -1,14 +1,8 @@ - - - - - - - Request to add an email address to your Matrix account - - -

A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:

- {{ link }} -

If this was not you, you can safely ignore this email. Thank you.

- - +{% extends "_base.html" %} +{% block title %}Request to add an email address to your Matrix account{% endblock %} + +{% block body %} +

A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:

+{{ link }} +

If this was not you, you can safely ignore this email. Thank you.

+{% endblock %} diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html index bd627ee9ce..f6d7e33825 100644 --- a/synapse/res/templates/add_threepid_failure.html +++ b/synapse/res/templates/add_threepid_failure.html @@ -1,13 +1,7 @@ - - - - - - - Request failed - - -

The request failed for the following reason: {{ failure_reason }}.

-

No changes have been made to your account.

- - +{% extends "_base.html" %} +{% block title %}Request failed{% endblock %} + +{% block body %} +

The request failed for the following reason: {{ failure_reason }}.

+

No changes have been made to your account.

+{% endblock %} diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html index 49170c138e..6d45111796 100644 --- a/synapse/res/templates/add_threepid_success.html +++ b/synapse/res/templates/add_threepid_success.html @@ -1,12 +1,6 @@ - - - - - - - Your email has now been validated - - -

Your email has now been validated, please return to your client. You may now close this window.

- - \ No newline at end of file +{% extends "_base.html" %} +{% block title %}Your email has now been validated{% endblock %} + +{% block body %} +

Your email has now been validated, please return to your client. You may now close this window.

+{% endblock %} diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html index 2d6ac44a0e..9178332f59 100644 --- a/synapse/res/templates/auth_success.html +++ b/synapse/res/templates/auth_success.html @@ -1,21 +1,21 @@ - - -Success! - - +{% extends "_base.html" %} +{% block title %}Success!{% endblock %} + +{% block header %} - - -
-

Thank you

-

You may now close this window and return to the application

-
- - +{% endblock %} + +{% block body %} +
+

Thank you

+

You may now close this window and return to the application

+
+ +{% endblock %} diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html index 2c7c384fe3..d0b1dae669 100644 --- a/synapse/res/templates/invalid_token.html +++ b/synapse/res/templates/invalid_token.html @@ -1,12 +1,5 @@ - - - - - - - Invalid renewal token. - - - Invalid renewal token. - - +{% block title %}Invalid renewal token.{% endblock %} + +{% block body %} +

Invalid renewal token.

+{% endblock %} diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html index 865f9f7ada..406397aaca 100644 --- a/synapse/res/templates/notice_expiry.html +++ b/synapse/res/templates/notice_expiry.html @@ -1,47 +1,46 @@ - - - - - - - - - - - - - - -
- - - - - - - - -
-
Hi {{ display_name }},
-
-
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
-
To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):
- -
-
- - +{% extends "_base.html" %} +{% block title %}Notice of expiry{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} + + + + + + +
+ + + + + + + + +
+
Hi {{ display_name }},
+
+
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
+
To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):
+ +
+
+{% endblock %} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html index 9dba0c0253..939d40315f 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html @@ -1,59 +1,57 @@ - - - - - - - - - - - - - - -
- - - - - -
-
Hi {{ user_display_name }},
-
{{ summary_text }}
-
- {%- for room in rooms %} - {%- include 'room.html' with context %} - {%- endfor %} - -
- - +{% block title %}New activity in room{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} + + + + + + +
+ + + + + +
+
Hi {{ user_display_name }},
+
{{ summary_text }}
+
+ {%- for room in rooms %} + {%- include 'room.html' with context %} + {%- endfor %} + +
+{% endblock %} diff --git a/synapse/res/templates/password_reset.html b/synapse/res/templates/password_reset.html index a8bdce357b..de5a9ec68f 100644 --- a/synapse/res/templates/password_reset.html +++ b/synapse/res/templates/password_reset.html @@ -1,14 +1,9 @@ - - - Password reset - - - - -

A password reset request has been received for your Matrix account. If this was you, please click the link below to confirm resetting your password:

+{% block title %}Password reset{% endblock %} - {{ link }} +{% block body %} +

A password reset request has been received for your Matrix account. If this was you, please click the link below to confirm resetting your password:

-

If this was not you, do not click the link above and instead contact your server administrator. Thank you.

- - +{{ link }} + +

If this was not you, do not click the link above and instead contact your server administrator. Thank you.

+{% endblock %} diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html index 2e3fd2ec1e..0eac64b6a8 100644 --- a/synapse/res/templates/password_reset_confirmation.html +++ b/synapse/res/templates/password_reset_confirmation.html @@ -1,10 +1,6 @@ - - - Password reset confirmation - - - - +{% block title %}Password reset confirmation{% endblock %} + +{% block body %}
@@ -15,6 +11,4 @@ If you did not mean to do this, please close this page and your password will not be changed.

- - - +{% endblock %} diff --git a/synapse/res/templates/password_reset_failure.html b/synapse/res/templates/password_reset_failure.html index 2d59c463f0..977babdb40 100644 --- a/synapse/res/templates/password_reset_failure.html +++ b/synapse/res/templates/password_reset_failure.html @@ -1,12 +1,6 @@ - - - Password reset failure - - - - -

The request failed for the following reason: {{ failure_reason }}.

+{% block title %}Password reset failure{% endblock %} +{% block body %} +

The request failed for the following reason: {{ failure_reason }}.

Your password has not been reset.

- - +{% endblock %} diff --git a/synapse/res/templates/password_reset_success.html b/synapse/res/templates/password_reset_success.html index 5165bd1fa2..0e99fad7ff 100644 --- a/synapse/res/templates/password_reset_success.html +++ b/synapse/res/templates/password_reset_success.html @@ -1,9 +1,5 @@ - - - - - - +{% block title %}Password reset success{% endblock %} + +{% block body %}

Your email has now been validated, please return to your client to reset your password. You may now close this window.

- - +{% endblock %} diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html index 615d3239c6..feaf3f6aed 100644 --- a/synapse/res/templates/recaptcha.html +++ b/synapse/res/templates/recaptcha.html @@ -1,10 +1,7 @@ - - -Authentication - - - +{% block title %}Authentication{% endblock %} + +{% block header %} + - - +{% endblock %} + +{% block body %}
{% if error is defined %} @@ -37,5 +35,4 @@ function captchaDone() {
- - +{% endblock %} \ No newline at end of file diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html index 20e831ff4a..189960a832 100644 --- a/synapse/res/templates/registration.html +++ b/synapse/res/templates/registration.html @@ -1,16 +1,11 @@ - - - Registration - - - - -

You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:

+{% block title %}Registration{% endblock %} - Verify Your Email Address +{% block body %} +

You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:

-

If this was not you, you can safely disregard this email.

+Verify Your Email Address -

Thank you.

- - +

If this was not you, you can safely disregard this email.

+ +

Thank you.

+{% endblock %} diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html index a6ed22bc90..3debe9301d 100644 --- a/synapse/res/templates/registration_failure.html +++ b/synapse/res/templates/registration_failure.html @@ -1,9 +1,5 @@ - - - - - - +{% block title %}Registration failure{% endblock %} + +{% block body %}

Validation failed for the following reason: {{ failure_reason }}.

- - +{% endblock %} diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html index d51d5549d8..e2dd020a9e 100644 --- a/synapse/res/templates/registration_success.html +++ b/synapse/res/templates/registration_success.html @@ -1,10 +1,5 @@ - - - Your email has now been validated - - - - +{% block title %}Your email has now been validated{% endblock %} + +{% block body %}

Your email has now been validated, please return to your client. You may now close this window.

- - +{% endblock %} diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html index 59a98f564c..2ee5866ba5 100644 --- a/synapse/res/templates/registration_token.html +++ b/synapse/res/templates/registration_token.html @@ -1,11 +1,10 @@ - - -Authentication - - +{% block title %}Authentication{% endblock %} + +{% block header %} - - +{% endblock %} + +{% block body %}
{% if error is defined %} @@ -19,5 +18,4 @@
- - +{% endblock %} diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html index 075f801cec..c634229840 100644 --- a/synapse/res/templates/sso_account_deactivated.html +++ b/synapse/res/templates/sso_account_deactivated.html @@ -1,25 +1,24 @@ - - - - - SSO account deactivated - - - - -
-

Your account has been deactivated

-

- No account found -

-

- Your account might have been deactivated by the server administrator. - You can either try to create a new account or contact the server’s - administrator. -

-
- {% include "sso_footer.html" without context %} - - +{% block title %}SSO account deactivated{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} +
+
+

Your account has been deactivated

+

+ No account found +

+

+ Your account might have been deactivated by the server administrator. + You can either try to create a new account or contact the server’s + administrator. +

+
+
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index 2d1db386e1..b516333373 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -1,189 +1,185 @@ - - - - Create your account - - - - - - - -
-

Create your account

-

This is required. Continue to create your account on {{ server_name }}. You can't change this later.

-
-
-
-
- -
@
- -
:{{ server_name }}
+{% block title %}Create your account{% endblock %} + +{% block header %} + + +{% endblock %} + +{% block body %} +
+

Create your account

+

This is required. Continue to create your account on {{ server_name }}. You can't change this later.

+
+
+ +
+ +
@
+ +
:{{ server_name }}
+
+ + + {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %} +
+

{% if idp.idp_icon %}{% endif %}Optional data from {{ idp.idp_name }}

+ {% if user_attributes.avatar_url %} +
- {% include "sso_footer.html" without context %} - - - + + + {% endif %} + {% if user_attributes.display_name %} + + {% endif %} + {% for email in user_attributes.emails %} + + {% endfor %} + + {% endif %} + +
+{% include "sso_footer.html" without context %} + +{% endblock %} diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html index 94403fc3ce..69fdcc9ef0 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -1,27 +1,25 @@ - - - - - Authentication failed - - - - - -
-

That doesn't look right

-

- We were unable to validate your {{ server_name }} account - via single sign‑on (SSO), because the SSO Identity - Provider returned different details than when you logged in. -

-

- Try the operation again, and ensure that you use the same details on - the Identity Provider as when you log into your account. -

-
- {% include "sso_footer.html" without context %} - - +{% block title %}Authentication failed{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} +
+
+

That doesn't look right

+

+ We were unable to validate your {{ server_name }} account + via single sign‑on (SSO), because the SSO Identity + Provider returned different details than when you logged in. +

+

+ Try the operation again, and ensure that you use the same details on + the Identity Provider as when you log into your account. +

+
+
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html index aa1c974a6b..2d106e0ae4 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html @@ -1,30 +1,26 @@ - - - - - Confirm it's you - - - - - -
-

Confirm it's you to continue

-

- A client is trying to {{ description }}. To confirm this action - re-authorize your account with single sign-on. -

-

- If you did not expect this, your account may be compromised. -

-
-
- - Continue with {{ idp.idp_name }} - -
- {% include "sso_footer.html" without context %} - - +{% block title %}Confirm it's you{% endblock %} + +{% block header %} + +{% endblock %} + +{% block body %} +
+

Confirm it's you to continue

+

+ A client is trying to {{ description }}. To confirm this action + re-authorize your account with single sign-on. +

+

+ If you did not expect this, your account may be compromised. +

+
+
+ + Continue with {{ idp.idp_name }} + +
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html index 4898af6011..56150eaefe 100644 --- a/synapse/res/templates/sso_auth_success.html +++ b/synapse/res/templates/sso_auth_success.html @@ -1,29 +1,25 @@ - - - - - Authentication successful - - - - - - -
-

Thank you

-

- Now we know it’s you, you can close this window and return to the - application. -

-
- {% include "sso_footer.html" without context %} - - +{% block title %}Authentication successful{% endblock %} + +{% block header %} + + +{% endblock %} + +{% block body %} +
+

Thank you

+

+ Now we know it’s you, you can close this window and return to the + application. +

+
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index 19992ff2ad..e394a92623 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -1,19 +1,19 @@ - - - - - Authentication failed - - - - - +{% block header %} +{% if error == "unauthorised" %} + +{% endif %} +{% endblock %} + +{% block body %} +
{# If an error of unauthorised is returned it means we have actively rejected their login #} {% if error == "unauthorised" %}
@@ -66,5 +66,5 @@ } {% endif %} - - +
+{% endblock %} diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index 56fabfa3d2..a2772ca9ef 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -1,63 +1,59 @@ - - - - - - - Choose identity provider - - - -
-

Log in to {{ server_name }}

-

Choose an identity provider to log in

-
-
- -
- {% include "sso_footer.html" without context %} - - + .providers a { + display: block; + border-radius: 4px; + border: 1px solid #17191C; + padding: 8px; + text-align: center; + text-decoration: none; + color: #17191C; + display: flex; + align-items: center; + font-weight: bold; + } + + .providers a img { + width: 24px; + height: 24px; + } + .providers a span { + flex: 1; + } + +{% endblock %} + +{% block body %} +
+

Log in to {{ server_name }}

+

Choose an identity provider to log in

+
+
+ +
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html index 523f64c4fc..126887d26c 100644 --- a/synapse/res/templates/sso_new_user_consent.html +++ b/synapse/res/templates/sso_new_user_consent.html @@ -1,33 +1,29 @@ - - - - - Agree to terms and conditions - - - - - -
-

Your account is nearly ready

-

Agree to the terms to create your account.

-
-
- {% include "sso_partial_profile.html" %} - -
- {% include "sso_footer.html" without context %} - - +{% block header %} + +{% endblock %} + +{% block body %} +
+

Your account is nearly ready

+

Agree to the terms to create your account.

+
+
+ {% include "sso_partial_profile.html" %} + +
+{% include "sso_footer.html" without context %} +{% endblock %} diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html index 1049a9bd92..887ee0d294 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -1,41 +1,38 @@ - - - - - Continue to your account - - - - - -
-

Continue to your account

-
-
- {% include "sso_partial_profile.html" %} -

Continuing will grant {{ display_url }} access to your account.

- Continue -
- {% include "sso_footer.html" without context %} - - + .confirm-trust { + margin: 34px 0; + color: #8D99A5; + } + .confirm-trust strong { + color: #17191C; + } + + .confirm-trust::before { + content: ""; + background-image: url('data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTgiIGhlaWdodD0iMTgiIHZpZXdCb3g9IjAgMCAxOCAxOCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZmlsbC1ydWxlPSJldmVub2RkIiBjbGlwLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xNi41IDlDMTYuNSAxMy4xNDIxIDEzLjE0MjEgMTYuNSA5IDE2LjVDNC44NTc4NiAxNi41IDEuNSAxMy4xNDIxIDEuNSA5QzEuNSA0Ljg1Nzg2IDQuODU3ODYgMS41IDkgMS41QzEzLjE0MjEgMS41IDE2LjUgNC44NTc4NiAxNi41IDlaTTcuMjUgOUM3LjI1IDkuNDY1OTYgNy41Njg2OSA5Ljg1NzQ4IDggOS45Njg1VjEyLjM3NUM4IDEyLjkyNzMgOC40NDc3MiAxMy4zNzUgOSAxMy4zNzVIMTAuMTI1QzEwLjY3NzMgMTMuMzc1IDExLjEyNSAxMi45MjczIDExLjEyNSAxMi4zNzVDMTEuMTI1IDExLjgyMjcgMTAuNjc3MyAxMS4zNzUgMTAuMTI1IDExLjM3NUgxMFY5QzEwIDguOTY1NDggOS45OTgyNSA4LjkzMTM3IDkuOTk0ODQgOC44OTc3NkM5Ljk0MzYzIDguMzkzNSA5LjUxNzc3IDggOSA4SDguMjVDNy42OTc3MiA4IDcuMjUgOC40NDc3MiA3LjI1IDlaTTkgNy41QzkuNjIxMzIgNy41IDEwLjEyNSA2Ljk5NjMyIDEwLjEyNSA2LjM3NUMxMC4xMjUgNS43NTM2OCA5LjYyMTMyIDUuMjUgOSA1LjI1QzguMzc4NjggNS4yNSA3Ljg3NSA1Ljc1MzY4IDcuODc1IDYuMzc1QzcuODc1IDYuOTk2MzIgOC4zNzg2OCA3LjUgOSA3LjVaIiBmaWxsPSIjQzFDNkNEIi8+Cjwvc3ZnPgoK'); + background-repeat: no-repeat; + width: 24px; + height: 24px; + display: block; + float: left; + } + +{% endblock %} + +{% block body %} +
+

Continue to your account

+
+
+ {% include "sso_partial_profile.html" %} +

Continuing will grant {{ display_url }} access to your account.

+ Continue +
+{% include "sso_footer.html" without context %} + +{% endblock %} diff --git a/synapse/res/templates/style.css b/synapse/res/templates/style.css new file mode 100644 index 0000000000..097b235ae5 --- /dev/null +++ b/synapse/res/templates/style.css @@ -0,0 +1,29 @@ +html { + height: 100%; +} + +body { + background: #f9fafb; + max-width: 680px; + margin: auto; + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol"; +} + +.mx_Header { + border-bottom: 3px solid #ddd; + margin-bottom: 1rem; + padding-top: 1rem; + padding-bottom: 1rem; + text-align: center; +} + +@media screen and (max-width: 1120px) { + body { + font-size: 20px; + } + + h1 { font-size: 1rem; } + h2 { font-size: .9rem; } + h3 { font-size: .85rem; } + h4 { font-size: .8rem; } +} diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html index 2081d990ab..977c3d0bc7 100644 --- a/synapse/res/templates/terms.html +++ b/synapse/res/templates/terms.html @@ -1,11 +1,10 @@ - - -Authentication - - +{% block title %}Authentication{% endblock %} + +{% block header %} - - +{% endblock %} + +{% block body %}
{% if error is defined %} @@ -19,5 +18,4 @@
- - +{% endblock %} -- cgit 1.4.1 From b7a7ff6ee39da4981dcfdce61bf8ac4735e3d047 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 21 Oct 2022 10:46:22 -0700 Subject: Add initial power level event to batch of bulk persisted events when creating a new room. (#14228) --- changelog.d/14228.misc | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 14 ++---- synapse/handlers/room.py | 39 ++++----------- synapse/push/bulk_push_rule_evaluator.py | 74 ++++++++++++++++++++++++----- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/_base.py | 2 +- 8 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14228.misc (limited to 'synapse') diff --git a/changelog.d/14228.misc b/changelog.d/14228.misc new file mode 100644 index 0000000000..14fe31a8bc --- /dev/null +++ b/changelog.d/14228.misc @@ -0,0 +1 @@ +Add initial power level event to batch of bulk persisted events when creating a new room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 275a37a575..4fbc79a6cb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1017,7 +1017,9 @@ class FederationHandler: context = EventContext.for_outlier(self._storage_controllers) - await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] + ) try: await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06e41b5cc0..7da6316a82 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2171,8 +2171,8 @@ class FederationEventHandler: min_depth, ) else: - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] ) try: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15b828dd74..468900a07f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1433,17 +1433,9 @@ class EventCreationHandler: a room that has been un-partial stated. """ - for event, context in events_and_context: - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + events_and_context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..cc1e5c8f97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1055,9 +1055,6 @@ class RoomCreationHandler: event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 - # the last event sent/persisted to the db - last_sent_event_id: Optional[str] = None - # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1102,26 +1099,6 @@ class RoomCreationHandler: return new_event, new_context - async def send( - event: EventBase, - context: synapse.events.snapshot.EventContext, - creator: Requester, - ) -> int: - nonlocal last_sent_event_id - - ev = await self.event_creation_handler.handle_new_client_event( - requester=creator, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) - - last_sent_event_id = ev.event_id - - # we know it was persisted, so must have a stream ordering - assert ev.internal_metadata.stream_ordering - return ev.internal_metadata.stream_ordering - try: config = self._presets_dict[preset_config] except KeyError: @@ -1135,10 +1112,14 @@ class RoomCreationHandler: ) logger.debug("Sending %s in new room", EventTypes.Member) - await send(creation_event, creation_context, creator) + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + events_and_context=[(creation_event, creation_context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + last_sent_event_id = ev.event_id - # Room create event must exist at this point - assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( creator, creator.user, @@ -1157,6 +1138,7 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) @@ -1165,7 +1147,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - await send(power_event, power_context, creator) + events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1214,9 +1196,8 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - await send(pl_event, pl_context, creator) + events_to_send.append((pl_event, pl_context)) - events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a75386f6a0..d7795a9080 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -165,8 +165,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -177,15 +190,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -194,16 +218,38 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) + + @measure_func("action_for_event_by_user") + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], + ) -> None: + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -223,7 +269,9 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) # Find the event's thread ID. relation = relation_from_event(event) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 675d7df2ac..594e7937a8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise - self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ce53f808db..121f3d8d65 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) -- cgit 1.4.1 From 1469fed0e39d31a063e8a54c2ea027774eec6acb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 24 Oct 2022 10:45:10 +0100 Subject: Add debugging to help diagnose lost device-list-update (#14268) --- changelog.d/14268.misc | 1 + synapse/storage/databases/main/devices.py | 54 +++++++++++++++++++++---------- 2 files changed, 38 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14268.misc (limited to 'synapse') diff --git a/changelog.d/14268.misc b/changelog.d/14268.misc new file mode 100644 index 0000000000..894b1e1d4c --- /dev/null +++ b/changelog.d/14268.misc @@ -0,0 +1 @@ +Add debugging to help diagnose lost device-list-update. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 830b076a32..979dd4e17e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -274,6 +274,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): destination, int(from_stream_id) ) if not has_changed: + # debugging for https://github.com/matrix-org/synapse/issues/14251 + issue_8631_logger.debug( + "%s: no change between %i and %i", + destination, + from_stream_id, + now_stream_id, + ) return now_stream_id, [] updates = await self.db_pool.runInteraction( @@ -1848,7 +1855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Iterable[str], + device_id: str, hosts: Collection[str], stream_ids: List[int], context: Optional[Dict[str, str]], @@ -1864,6 +1871,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_id_iterator = iter(stream_ids) encoded_context = json_encoder.encode(context) + mark_sent = not self.hs.is_mine_id(user_id) + + values = [ + ( + destination, + next(stream_id_iterator), + user_id, + device_id, + mark_sent, + now, + encoded_context if whitelisted_homeserver(destination) else "{}", + ) + for destination in hosts + ] + self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", @@ -1876,23 +1898,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "ts", "opentracing_context", ), - values=[ - ( - destination, - next(stream_id_iterator), - user_id, - device_id, - not self.hs.is_mine_id( - user_id - ), # We only need to send out update for *our* users - now, - encoded_context if whitelisted_homeserver(destination) else "{}", - ) - for destination in hosts - for device_id in device_ids - ], + values=values, ) + # debugging for https://github.com/matrix-org/synapse/issues/14251 + if issue_8631_logger.isEnabledFor(logging.DEBUG): + issue_8631_logger.debug( + "Recorded outbound pokes for %s:%s with device stream ids %s", + user_id, + device_id, + { + stream_id: destination + for (destination, stream_id, _, _, _, _, _) in values + }, + ) + def _add_device_outbound_room_poke_txn( self, txn: LoggingTransaction, @@ -1997,7 +2017,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._add_device_outbound_poke_to_stream_txn( txn, user_id=user_id, - device_ids=[device_id], + device_id=device_id, hosts=hosts, stream_ids=stream_ids, context=context, -- cgit 1.4.1 From 09b588854e3a6abc4ea2eaa68bb0345f23be5ce8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 24 Oct 2022 13:05:14 +0100 Subject: Fix `TypeError: 'dict_keys' object is not reversible` (#14280) --- changelog.d/14280.bugfix | 1 + synapse/federation/sender/__init__.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14280.bugfix (limited to 'synapse') diff --git a/changelog.d/14280.bugfix b/changelog.d/14280.bugfix new file mode 100644 index 0000000000..c546d2be48 --- /dev/null +++ b/changelog.d/14280.bugfix @@ -0,0 +1 @@ +Fix broken outbound federation when using Python 3.7. Broke in v1.70.0rc1. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 774ecd81b6..3ad483efe0 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -536,8 +536,7 @@ class FederationSender(AbstractFederationSender): if event_entries: now = self.clock.time_msec() - last_id = next(reversed(event_ids)) - ts = event_to_received_ts[last_id] + ts = max(t for t in event_to_received_ts.values() if t) assert ts is not None synapse.metrics.event_processing_lag.labels( -- cgit 1.4.1 From 19c0e55ef7742d67cff1cb6fb7c3e862b86ea788 Mon Sep 17 00:00:00 2001 From: Ryan Miguel <1818590+renegaderyu@users.noreply.github.com> Date: Mon, 24 Oct 2022 08:55:06 -0700 Subject: Return NOT_JSON if decode fails and defer set_timeline_upper_limit ca… (#14262) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Return NOT_JSON if decode fails and defer set_timeline_upper_limit call until after check_valid_filter. Fixes #13661. Signed-off-by: Ryan Miguel . * Reword changelog --- changelog.d/14262.misc | 1 + synapse/rest/client/sync.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14262.misc (limited to 'synapse') diff --git a/changelog.d/14262.misc b/changelog.d/14262.misc new file mode 100644 index 0000000000..c1d23bc67d --- /dev/null +++ b/changelog.d/14262.misc @@ -0,0 +1 @@ +Provide a specific error code when a `/sync` request provides a filter which doesn't represent a JSON object. diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8a16459105..f2013faeb2 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -146,12 +146,12 @@ class SyncRestServlet(RestServlet): elif filter_id.startswith("{"): try: filter_object = json_decoder.decode(filter_id) - set_timeline_upper_limit( - filter_object, self.hs.config.server.filter_timeline_limit - ) except Exception: - raise SynapseError(400, "Invalid filter JSON") + raise SynapseError(400, "Invalid filter JSON", errcode=Codes.NOT_JSON) self.filtering.check_valid_filter(filter_object) + set_timeline_upper_limit( + filter_object, self.hs.config.server.filter_timeline_limit + ) filter_collection = FilterCollection(self.hs, filter_object) else: try: -- cgit 1.4.1 From 581b37b5d6c1c9430108930a4fe409cf3f86332f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Oct 2022 12:07:16 -0400 Subject: Revert behavior change for bundling edits of non-message events (#14283) --- changelog.d/14283.bugfix | 1 + synapse/storage/databases/main/relations.py | 11 +++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14283.bugfix (limited to 'synapse') diff --git a/changelog.d/14283.bugfix b/changelog.d/14283.bugfix new file mode 100644 index 0000000000..a80a8c0361 --- /dev/null +++ b/changelog.d/14283.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where edits to non-message events were aggregated by the homeserver. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 1de62ee9df..c022510e76 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -484,11 +484,12 @@ class RelationsWorkerStore(SQLBaseStore): the event will map to None. """ - # We only allow edits for events that have the same sender and event type. - # We can't assert these things during regular event auth so we have to do - # the checks post hoc. + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. - # Fetches latest edit that has the same type and sender as the original. + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. if isinstance(self.database_engine, PostgresEngine): # The `DISTINCT ON` clause will pick the *first* row it encounters, # so ordering by origin server ts + event ID desc will ensure we get @@ -504,6 +505,7 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? + AND edit.type = 'm.room.message' ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC """ else: @@ -522,6 +524,7 @@ class RelationsWorkerStore(SQLBaseStore): WHERE %s AND relation_type = ? + AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts, edit.event_id """ -- cgit 1.4.1 From 8c94dd3a277d4e11192f98a9ca32cb6638606b66 Mon Sep 17 00:00:00 2001 From: asymmetric Date: Tue, 25 Oct 2022 11:22:55 +0200 Subject: Enable WAL for SQLite (#13897) Signed-off-by: Lorenzo Manacorda --- changelog.d/13897.feature | 1 + synapse/storage/engines/sqlite.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/13897.feature (limited to 'synapse') diff --git a/changelog.d/13897.feature b/changelog.d/13897.feature new file mode 100644 index 0000000000..d46fdf9fa5 --- /dev/null +++ b/changelog.d/13897.feature @@ -0,0 +1 @@ +Enable Write-Ahead Logging for SQLite installs. Contributed by [asymmetric](https://github.com/asymmetric). diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index faa574dbfd..14260442b6 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -88,6 +88,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): db_conn.create_function("rank", 1, _rank) db_conn.execute("PRAGMA foreign_keys = ON;") + + # Enable WAL. + # see https://www.sqlite.org/wal.html + db_conn.execute("PRAGMA journal_mode = WAL;") db_conn.commit() def is_deadlock(self, error: Exception) -> bool: -- cgit 1.4.1 From c9dffd5b330553c5803784be5bc0e2479fab79b0 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 25 Oct 2022 11:39:25 +0100 Subject: Remove unused `@lru_cache` decorator (#13595) * Remove unused `@lru_cache` decorator Spotted this working on something else. Co-authored-by: David Robertson --- changelog.d/13595.misc | 1 + synapse/util/caches/descriptors.py | 104 ---------------------------------- tests/util/caches/test_descriptors.py | 40 ++----------- 3 files changed, 5 insertions(+), 140 deletions(-) create mode 100644 changelog.d/13595.misc (limited to 'synapse') diff --git a/changelog.d/13595.misc b/changelog.d/13595.misc new file mode 100644 index 0000000000..71959a6ee7 --- /dev/null +++ b/changelog.d/13595.misc @@ -0,0 +1 @@ +Remove unused `@lru_cache` decorator. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b3c748ef44..75428d19ba 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import functools import inspect import logging @@ -146,109 +145,6 @@ class _CacheDescriptorBase: ) -class _LruCachedFunction(Generic[F]): - cache: LruCache[CacheKey, Any] - __call__: F - - -def lru_cache( - *, max_entries: int = 1000, cache_context: bool = False -) -> Callable[[F], _LruCachedFunction[F]]: - """A method decorator that applies a memoizing cache around the function. - - This is more-or-less a drop-in equivalent to functools.lru_cache, although note - that the signature is slightly different. - - The main differences with functools.lru_cache are: - (a) the size of the cache can be controlled via the cache_factor mechanism - (b) the wrapped function can request a "cache_context" which provides a - callback mechanism to indicate that the result is no longer valid - (c) prometheus metrics are exposed automatically. - - The function should take zero or more arguments, which are used as the key for the - cache. Single-argument functions use that argument as the cache key; otherwise the - arguments are built into a tuple. - - Cached functions can be "chained" (i.e. a cached function can call other cached - functions and get appropriately invalidated when they called caches are - invalidated) by adding a special "cache_context" argument to the function - and passing that as a kwarg to all caches called. For example: - - @lru_cache(cache_context=True) - def foo(self, key, cache_context): - r1 = self.bar1(key, on_invalidate=cache_context.invalidate) - r2 = self.bar2(key, on_invalidate=cache_context.invalidate) - return r1 + r2 - - The wrapped function also has a 'cache' property which offers direct access to the - underlying LruCache. - """ - - def func(orig: F) -> _LruCachedFunction[F]: - desc = LruCacheDescriptor( - orig, - max_entries=max_entries, - cache_context=cache_context, - ) - return cast(_LruCachedFunction[F], desc) - - return func - - -class LruCacheDescriptor(_CacheDescriptorBase): - """Helper for @lru_cache""" - - class _Sentinel(enum.Enum): - sentinel = object() - - def __init__( - self, - orig: Callable[..., Any], - max_entries: int = 1000, - cache_context: bool = False, - ): - super().__init__( - orig, num_args=None, uncached_args=None, cache_context=cache_context - ) - self.max_entries = max_entries - - def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: - cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.name, - max_size=self.max_entries, - ) - - get_cache_key = self.cache_key_builder - sentinel = LruCacheDescriptor._Sentinel.sentinel - - @functools.wraps(self.orig) - def _wrapped(*args: Any, **kwargs: Any) -> Any: - invalidate_callback = kwargs.pop("on_invalidate", None) - callbacks = (invalidate_callback,) if invalidate_callback else () - - cache_key = get_cache_key(args, kwargs) - - ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) - if ret != sentinel: - return ret - - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - if self.add_cache_context: - kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) - - ret2 = self.orig(obj, *args, **kwargs) - cache.set(cache_key, ret2, callbacks=callbacks) - - return ret2 - - wrapped = cast(CachedFunction, _wrapped) - wrapped.cache = cache - obj.__dict__[self.name] = wrapped - - return wrapped - - class DeferredCacheDescriptor(_CacheDescriptorBase): """A method decorator that applies a memoizing cache around the function. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 78fd7b6961..43475a307f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -28,7 +28,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, cachedList, lru_cache +from synapse.util.caches.descriptors import cached, cachedList from tests import unittest from tests.test_utils import get_awaitable_result @@ -36,38 +36,6 @@ from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) -class LruCacheDecoratorTestCase(unittest.TestCase): - def test_base(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @lru_cache() - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - obj.mock.return_value = "fish" - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(1, 3) - obj.mock.reset_mock() - - # the two values should now be cached - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -478,10 +446,10 @@ class DescriptorTestCase(unittest.TestCase): @cached(cache_context=True) async def func2(self, key, cache_context): - return self.func3(key, on_invalidate=cache_context.invalidate) + return await self.func3(key, on_invalidate=cache_context.invalidate) - @lru_cache(cache_context=True) - def func3(self, key, cache_context): + @cached(cache_context=True) + async def func3(self, key, cache_context): self.invalidate = cache_context.invalidate return 42 -- cgit 1.4.1 From 2d0ba3f89aaf9545d81c4027500e543ec70b68a6 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Tue, 25 Oct 2022 13:38:01 +0000 Subject: Implementation for MSC3664: Pushrules for relations (#11804) --- changelog.d/11804.feature | 1 + rust/src/push/base_rules.rs | 17 +++ rust/src/push/evaluator.rs | 99 ++++++++++++- rust/src/push/mod.rs | 61 ++++++-- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 3 + synapse/push/bulk_push_rule_evaluator.py | 49 ++++++- synapse/rest/client/capabilities.py | 5 + synapse/storage/databases/main/push_rule.py | 15 +- tests/push/test_push_rule_evaluator.py | 215 +++++++++++++++++++++++++++- 10 files changed, 454 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11804.feature (limited to 'synapse') diff --git a/changelog.d/11804.feature b/changelog.d/11804.feature new file mode 100644 index 0000000000..6420393541 --- /dev/null +++ b/changelog.d/11804.feature @@ -0,0 +1 @@ +Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico. diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 63240cacfc..49802fa4eb 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -25,6 +25,7 @@ use crate::push::Action; use crate::push::Condition; use crate::push::EventMatchCondition; use crate::push::PushRule; +use crate::push::RelatedEventMatchCondition; use crate::push::SetTweak; use crate::push::TweakValue; @@ -114,6 +115,22 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch( + RelatedEventMatchCondition { + key: Some(Cow::Borrowed("sender")), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_id")), + rel_type: Cow::Borrowed("m.in_reply_to"), + include_fallbacks: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0365dd01dc..cedd42c54d 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -23,6 +23,7 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, + RelatedEventMatchCondition, }; lazy_static! { @@ -49,6 +50,13 @@ pub struct PushRuleEvaluator { /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, + + /// The related events, indexed by relation type. Flattened in the same manner as + /// `flattened_keys`. + related_events_flattened: BTreeMap>, + + /// If msc3664, push rules for related events, is enabled. + related_event_match_enabled: bool, } #[pymethods] @@ -60,6 +68,8 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, + related_events_flattened: BTreeMap>, + related_event_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -72,6 +82,8 @@ impl PushRuleEvaluator { room_member_count, notification_power_levels, sender_power_level, + related_events_flattened, + related_event_match_enabled, }) } @@ -156,6 +168,9 @@ impl PushRuleEvaluator { KnownCondition::EventMatch(event_match) => { self.match_event_match(event_match, user_id)? } + KnownCondition::RelatedEventMatch(event_match) => { + self.match_related_event_match(event_match, user_id)? + } KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -239,6 +254,79 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `related_event_match` condition. (MSC3664) + fn match_related_event_match( + &self, + event_match: &RelatedEventMatchCondition, + user_id: Option<&str>, + ) -> Result { + // First check if related event matching is enabled... + if !self.related_event_match_enabled { + return Ok(false); + } + + // get the related event, fail if there is none. + let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) { + event + } else { + return Ok(false); + }; + + // If we are not matching fallbacks, don't match if our special key indicating this is a + // fallback relation is not present. + if !event_match.include_fallbacks.unwrap_or(false) + && event.contains_key("im.vector.is_falling_back") + { + return Ok(false); + } + + // if we have no key, accept the event as matching, if it existed without matching any + // fields. + let key = if let Some(key) = &event_match.key { + key + } else { + return Ok(true); + }; + + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -267,8 +355,15 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = - PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 0dabfab8b8..d57800aa4a 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -267,6 +267,8 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + #[serde(rename = "im.nheko.msc3664.related_event_match")] + RelatedEventMatch(RelatedEventMatchCondition), ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -299,6 +301,20 @@ pub struct EventMatchCondition { pub pattern_type: Option>, } +/// The body of a [`Condition::RelatedEventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RelatedEventMatchCondition { + #[serde(skip_serializing_if = "Option::is_none")] + pub key: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern_type: Option>, + pub rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_fallbacks: Option, +} + /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] @@ -391,15 +407,21 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, + msc3664_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { + pub fn py_new( + push_rules: PushRules, + enabled_map: BTreeMap, + msc3664_enabled: bool, + ) -> Self { Self { push_rules, enabled_map, + msc3664_enabled, } } @@ -414,13 +436,25 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules.iter().map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules + .iter() + .filter(|rule| { + // Ignore disabled experimental push rules + if !self.msc3664_enabled + && rule.rule_id == "global/override/.im.nheko.msc3664.reply" + { + return false; + } + + true + }) + .map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } @@ -446,6 +480,17 @@ fn test_deserialize_condition() { let _: Condition = serde_json::from_str(json).unwrap(); } +#[test] +fn test_deserialize_unstable_msc3664_condition() { + let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RelatedEventMatch(_)) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index f2a61df660..f3b6d6c933 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,7 +25,9 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... + def __init__( + self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool + ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -37,6 +39,8 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], + related_events_flattened: Mapping[str, Mapping[str, str]], + related_event_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4009add01d..d9bdd66d55 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -98,6 +98,9 @@ class ExperimentalConfig(Config): # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) + # MSC3664: Pushrules to match on related events + self.msc3664_enabled: bool = experimental.get("msc3664_enabled", False) + # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d7795a9080..75b7e126ca 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - push_rules_invalidation_counter = Counter( "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" ) @@ -107,6 +106,8 @@ class BulkPushRuleEvaluator: self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -218,6 +219,48 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation + + Returns: + Mapping of relation type to flattened events. + """ + related_events: Dict[str, Dict[str, str]] = {} + if self._related_event_match_enabled: + related_event_id = event.content.get("m.relates_to", {}).get("event_id") + relation_type = event.content.get("m.relates_to", {}).get("rel_type") + if related_event_id is not None and relation_type is not None: + related_event = await self.store.get_event( + related_event_id, allow_none=True + ) + if related_event is not None: + related_events[relation_type] = _flatten_dict(related_event) + + reply_event_id = ( + event.content.get("m.relates_to", {}) + .get("m.in_reply_to", {}) + .get("event_id") + ) + + # convert replies to pseudo relations + if reply_event_id is not None: + related_event = await self.store.get_event( + reply_event_id, allow_none=True + ) + + if related_event is not None: + related_events["m.in_reply_to"] = _flatten_dict(related_event) + + # indicate that this is from a fallback relation. + if relation_type == "m.thread" and event.content.get( + "m.relates_to", {} + ).get("is_falling_back", False): + related_events["m.in_reply_to"][ + "im.vector.is_falling_back" + ] = "" + + return related_events + async def action_for_events_by_user( self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: @@ -286,6 +329,8 @@ class BulkPushRuleEvaluator: # the parent is part of a thread. thread_id = await self.store.get_thread_id(relation.parent_id) + related_events = await self._related_events(event) + # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. notification_levels = power_levels.get("notifications", {}) @@ -298,6 +343,8 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, + related_events, + self._related_event_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 4237071c61..e84dde31b1 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -77,6 +77,11 @@ class CapabilitiesRestServlet(RestServlet): "enabled": True, } + if self.config.experimental.msc3664_enabled: + response["capabilities"]["im.nheko.msc3664.related_event_match"] = { + "enabled": self.config.experimental.msc3664_enabled, + } + return HTTPStatus.OK, response diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 51416b2236..b6c15f29f8 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,6 +29,7 @@ from typing import ( ) from synapse.api.errors import StoreError +from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -62,7 +63,9 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], enabled_map: Dict[str, bool] + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -80,7 +83,9 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules(push_rules, enabled_map) + filtered_rules = FilteredPushRules( + push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled + ) return filtered_rules @@ -160,7 +165,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map) + return _load_rules(rows, enabled_map, self.hs.config.experimental) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -219,7 +224,9 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental + ) return results diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index decf619466..fe7c145840 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: + def _get_evaluator( + self, content: JsonDict, related_events=None + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), + {} if related_events is None else related_events, + True, ) def test_display_name(self) -> None: @@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) + def test_related_event_match(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.annotation", + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + "m.annotation": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.annotation", + "pattern": "@other_user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.replace", + }, + "@other_user:test", + "display_name", + ) + ) + + def test_related_event_match_with_fallback(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.thread", + "is_falling_back": True, + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + "im.vector.is_falling_back": "", + }, + "m.thread": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": True, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": False, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + + def test_related_event_match_no_related_event(self): + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Message without related event"} + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.4.1 From 9192d74b0bf2f87b00d3e106a18baa9ce27acda1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Oct 2022 16:25:02 +0200 Subject: Refactor OIDC tests to better mimic an actual OIDC provider. (#13910) This implements a fake OIDC server, which intercepts calls to the HTTP client. Improves accuracy of tests by covering more internal methods. One particular example was the ID token validation, which previously mocked. This uncovered an incorrect dependency: Synapse actually requires at least authlib 0.15.1, not 0.14.0. --- changelog.d/13910.misc | 1 + pyproject.toml | 2 +- synapse/handlers/oidc.py | 15 +- tests/federation/test_federation_client.py | 36 +- tests/handlers/test_oidc.py | 580 +++++++++++++---------------- tests/rest/client/test_auth.py | 32 +- tests/rest/client/test_login.py | 40 +- tests/rest/client/utils.py | 136 +++---- tests/test_utils/__init__.py | 40 +- tests/test_utils/oidc.py | 325 ++++++++++++++++ 10 files changed, 747 insertions(+), 460 deletions(-) create mode 100644 changelog.d/13910.misc create mode 100644 tests/test_utils/oidc.py (limited to 'synapse') diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc new file mode 100644 index 0000000000..e906952aab --- /dev/null +++ b/changelog.d/13910.misc @@ -0,0 +1 @@ +Refactor OIDC tests to better mimic an actual OIDC provider. diff --git a/pyproject.toml b/pyproject.toml index 6ebac41ed1..7e0feb75aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true } psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true } pysaml2 = { version = ">=4.5.0", optional = true } -authlib = { version = ">=0.14.0", optional = true } +authlib = { version = ">=0.15.1", optional = true } # systemd-python is necessary for logging to the systemd journal via # `systemd.journal.JournalHandler`, as is documented in # `contrib/systemd/log_config.yaml`. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d7a8226900..9759daf043 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -275,6 +275,7 @@ class OidcProvider: provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -673,6 +674,13 @@ class OidcProvider: Returns: The decoded claims in the ID token. """ + id_token = token.get("id_token") + logger.debug("Attempting to decode JWT id_token %r", id_token) + + # That has been theoritically been checked by the caller, so even though + # assertion are not enabled in production, it is mainly here to appease mypy + assert id_token is not None + metadata = await self.load_metadata() claims_params = { "nonce": nonce, @@ -688,9 +696,6 @@ class OidcProvider: claim_options = {"iss": {"values": [metadata["issuer"]]}} - id_token = token["id_token"] - logger.debug("Attempting to decode JWT id_token %r", id_token) - # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() @@ -715,7 +720,9 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew return claims diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index a538215931..51d3bb8fff 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest import mock import twisted.web.client from twisted.internet import defer -from twisted.internet.protocol import Protocol -from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions @@ -26,10 +23,9 @@ from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import event_injection +from tests.test_utils import FakeResponse, event_injection from tests.unittest import FederatingHomeserverTestCase @@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "pdus": [ create_event_dict, member_event_dict, @@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, "pdus": [ @@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # We expect an outbound request to /backfill, so stub that out self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, # Mimic the other server returning our new `pulled_event` @@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase): # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the # other from "yet.another.server" self.assertEqual(backfill_num_attempts, 2) - - -def _mock_response(resp: JsonDict): - body = json.dumps(resp).encode("utf-8") - - def deliver_body(p: Protocol): - p.dataReceived(body) - p.connectionLost(Failure(twisted.web.client.ResponseDone())) - - response = mock.Mock( - code=200, - phrase=b"OK", - headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), - length=len(body), - deliverBody=deliver_body, - ) - mock.seal(response) - return response diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b7..5955410524 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os -from typing import Any, Dict +from typing import Any, Dict, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -22,12 +21,15 @@ import pymacaroons from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException +from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +48,6 @@ BASE_URL = "https://synapse/" CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +62,9 @@ DEFAULT_CONFIG = { EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, - "response_types_supported": ["code"], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], - } - elif url == JWKS_URI: - return {"keys": []} - - return {} - - def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) + self.hs_patcher.start() self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase): # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def tearDown(self) -> None: + self.hs_patcher.stop() + return super().tearDown() + + def reset_mocks(self): + """Reset all the Mocks.""" + self.fake_server.reset_mocks() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_server.get_metadata() + metadata.update(values) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + with_sid: bool = False, + ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code, grant = self.fake_server.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=self.provider._client_auth.client_id, + redirect_uri=self.provider._callback_url, + nonce=nonce, + with_sid=with_sid, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session), grant def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_server.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_server.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_server.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.fake_server.get_jwks_handler.assert_called_once() - # Throw if the JWKS uri is missing - original = self.provider.load_metadata - - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request, _ = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, client_redirect_url, None, new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request, _ = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request, _ = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() + request, grant = self.start_authorization(userinfo, with_sid=True) + self.assertIsNotNone(grant.sid) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=grant.sid, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.fake_server.post_token_handler.return_value = FakeResponse( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase): """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" @@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) @@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase): """ Login while using a mapping provider that implements get_extra_attributes. """ - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request, _ = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, @@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase): Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", + idp_id=self.provider.idp_id, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict - - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. - - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), - ) - request = _build_callback_request("code", state, session) - - await handler.handle_oidc_callback(request) - - def _build_callback_request( code: str, state: str, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 090cef5216..ebf653d018 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase): * checking that the original operation succeeds """ + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # run the UIA-via-SSO flow session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - login_resp = self.helper.login_via_oidc("username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) channel = self.delete_device( @@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" + + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device @@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase): """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - login_resp = self.helper.login_via_oidc("username", expected_status=403) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, "username", expected_status=403 + ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) self.assertEqual( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e801ba8c8b..ff5baa9f0a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -36,7 +36,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): TEST_CLIENT_REDIRECT_URL, ) - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + channel, _ = self.helper.complete_oidc_auth( + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} + ) # that should serve a confirmation page self.assertEqual(channel.code, 200, channel.result) @@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request("oidc") + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" + fake_oidc_server = self.helper.fake_oidc_server() + # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, + {"sub": "tester", "displayname": "Jonny"}, + TEST_CLIENT_REDIRECT_URL, ) # that should redirect to the username picker diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c249a42bb6..967d229223 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,7 +31,6 @@ from typing import ( Tuple, overload, ) -from unittest.mock import patch from urllib.parse import urlencode import attr @@ -46,8 +45,19 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_ISSUER = "https://issuer.test/" +TEST_OIDC_CONFIG = { + "enabled": True, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} @attr.s(auto_attribs=True) @@ -543,12 +553,28 @@ class RestHelper: return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: + """Create a ``FakeOidcServer``. + + This can be used in conjuction with ``login_via_oidc``:: + + fake_oidc_server = self.helper.fake_oidc_server() + login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user") + """ + + return FakeOidcServer( + clock=self.hs.get_clock(), + issuer=issuer, + ) + def login_via_oidc( self, + fake_server: FakeOidcServer, remote_user_id: str, + with_sid: bool = False, expected_status: int = 200, - ) -> JsonDict: - """Log in via OIDC + ) -> Tuple[JsonDict, FakeAuthorizationGrant]: + """Log in (as a new user) via OIDC Returns the result of the final token login. @@ -560,7 +586,10 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + userinfo = {"sub": remote_user_id} + channel, grant = self.auth_via_oidc( + fake_server, userinfo, client_redirect_url, with_sid=with_sid + ) # expect a confirmation page assert channel.code == HTTPStatus.OK, channel.result @@ -585,14 +614,16 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body + return channel.json_body, grant def auth_via_oidc( self, + fake_server: FakeOidcServer, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. This can be used for either login or user-interactive auth. @@ -616,6 +647,7 @@ class RestHelper: the login redirect endpoint ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -625,14 +657,15 @@ class RestHelper: cookies: Dict[str, str] = {} - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + with fake_server.patch_homeserver(hs=self.hs): + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -640,17 +673,21 @@ class RestHelper: # that synapse passes to the client. oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + assert oauth_uri_path == fake_server.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + return self.complete_oidc_auth( + fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid + ) def complete_oidc_auth( self, + fake_serer: FakeOidcServer, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Mock out an OIDC authentication flow Assumes that an OIDC auth has been initiated by one of initiate_sso_login or @@ -661,50 +698,37 @@ class RestHelper: Requires the OIDC callback resource to be mounted at the normal place. Args: + fake_server: the fake OIDC server with which the auth should be done oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, from initiate_sso_login or initiate_sso_ui_auth). cookies: the cookies set by synapse's redirect endpoint, which will be sent back to the callback endpoint. user_info_dict: the remote userinfo that the OIDC provider should present. Typically this should be '{"sub": ""}'. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. """ _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) + + code, grant = fake_serer.start_authorization( + scope=params["scope"][0], + userinfo=user_info_dict, + client_id=params["client_id"][0], + redirect_uri=params["redirect_uri"][0], + nonce=params["nonce"][0], + with_sid=with_sid, + ) + state = params["state"][0] + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + urllib.parse.urlencode({"state": state, "code": code}), ) - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req( - method: str, - uri: str, - data: Optional[dict] = None, - headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=HTTPStatus.OK, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( self.hs.get_reactor(), @@ -715,7 +739,7 @@ class RestHelper: ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) - return channel + return channel, grant def initiate_sso_login( self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] @@ -806,21 +830,3 @@ class RestHelper: assert len(p.links) == 1, "not exactly one link in confirmation page" oauth_uri = p.links[0] return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc. -TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0d0d6faf0d..e62ebcc6a5 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -15,17 +15,24 @@ """ Utilities for running the unit tests """ +import json import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Tuple, TypeVar from unittest.mock import Mock import attr +import zope.interface from twisted.python.failure import Failure from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.types import JsonDict TV = TypeVar("TV") @@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock: return Mock(side_effect=cb) -@attr.s -class FakeResponse: +# Type ignore: it does not fully implement IResponse, but is good enough for tests +@zope.interface.implementer(IResponse) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeResponse: # type: ignore[misc] """A fake twisted.web.IResponse object there is a similar class at treq.test.test_response, but it lacks a `phrase` attribute, and didn't support deliverBody until recently. """ - # HTTP response code - code = attr.ib(type=int) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) - # HTTP response phrase (eg b'OK' for a 200) - phrase = attr.ib(type=bytes) + # HTTP response code + code: int = 200 # body of the response - body = attr.ib(type=bytes) + body: bytes = b"" + + headers: Headers = attr.Factory(Headers) + + @property + def phrase(self): + return RESPONSES.get(self.code, b"Unknown Status") + + @property + def length(self): + return len(self.body) def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + @classmethod + def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + headers = Headers({"Content-Type": ["application/json"]}) + body = json.dumps(payload).encode("utf-8") + return cls(code=code, body=body, headers=headers) + # A small image used in some tests. # diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py new file mode 100644 index 0000000000..de134bbc89 --- /dev/null +++ b/tests/test_utils/oidc.py @@ -0,0 +1,325 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import attr + +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.server import HomeServer +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests.test_utils import FakeResponse + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeAuthorizationGrant: + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + sid: Optional[str] + + +class FakeOidcServer: + """A fake OpenID Connect Provider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks_handler: Mock + get_metadata_handler: Mock + get_userinfo_handler: Mock + post_token_handler: Mock + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self._clock = clock + self.issuer = issuer + + self.request = Mock(side_effect=self._request) + self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler) + self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler) + self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler) + self.post_token_handler = Mock(side_effect=self._post_token_handler) + + # A code -> grant mapping + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + # An access token -> grant mapping + self._sessions: Dict[str, FakeAuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) + + self._id_token_overrides: Dict[str, Any] = {} + + def reset_mocks(self): + self.request.reset_mock() + self.get_jwks_handler.reset_mock() + self.get_metadata_handler.reset_mock() + self.get_userinfo_handler.reset_mock() + self.post_token_handler.reset_mock() + + def patch_homeserver(self, hs: HomeServer): + """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. + + This patch should be used whenever the HS is expected to perform request to the + OIDC provider, e.g.:: + + fake_oidc_server = self.helper.fake_oidc_server() + with fake_oidc_server.patch_homeserver(hs): + self.make_request("GET", "/_matrix/client/r0/login/sso/redirect") + """ + return patch.object(hs.get_proxied_http_client(), "request", self.request) + + @property + def authorization_endpoint(self) -> str: + return self.issuer + "authorize" + + @property + def token_endpoint(self) -> str: + return self.issuer + "token" + + @property + def userinfo_endpoint(self) -> str: + return self.issuer + "userinfo" + + @property + def metadata_endpoint(self) -> str: + return self.issuer + ".well-known/openid-configuration" + + @property + def jwks_uri(self) -> str: + return self.issuer + "jwks" + + def get_metadata(self) -> dict: + return { + "issuer": self.issuer, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "jwks_uri": self.jwks_uri, + "userinfo_endpoint": self.userinfo_endpoint, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["ES256"], + } + + def get_jwks(self) -> dict: + return self._jwks.as_dict() + + def get_userinfo(self, access_token: str) -> Optional[dict]: + """Given an access token, get the userinfo of the associated session.""" + session = self._sessions.get(access_token, None) + if session is None: + return None + return session.userinfo + + def _sign(self, payload: dict) -> str: + from authlib.jose import JsonWebSignature + + jws = JsonWebSignature() + kid = self.get_jwks()["keys"][0]["kid"] + protected = {"alg": "ES256", "kid": kid} + json_payload = json.dumps(payload) + return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") + + def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: + now = self._clock.time() + id_token = { + **grant.userinfo, + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "nbf": now, + "exp": now + 600, + } + + if grant.nonce is not None: + id_token["nonce"] = grant.nonce + + if grant.sid is not None: + id_token["sid"] = grant.sid + + id_token.update(self._id_token_overrides) + + return self._sign(id_token) + + def id_token_override(self, overrides: dict): + """Temporarily patch the ID token generated by the token endpoint.""" + return patch.object(self, "_id_token_overrides", overrides) + + def start_authorization( + self, + client_id: str, + scope: str, + redirect_uri: str, + userinfo: dict, + nonce: Optional[str] = None, + with_sid: bool = False, + ) -> Tuple[str, FakeAuthorizationGrant]: + """Start an authorization request, and get back the code to use on the authorization endpoint.""" + code = random_string(10) + sid = None + if with_sid: + sid = random_string(10) + + grant = FakeAuthorizationGrant( + userinfo=userinfo, + scope=scope, + redirect_uri=redirect_uri, + nonce=nonce, + client_id=client_id, + sid=sid, + ) + self._authorization_grants[code] = grant + + return code, grant + + def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: + grant = self._authorization_grants.pop(code, None) + if grant is None: + return None + + access_token = random_string(10) + self._sessions[access_token] = grant + + token = { + "token_type": "Bearer", + "access_token": access_token, + "expires_in": 3600, + "scope": grant.scope, + } + + if "openid" in grant.scope: + token["id_token"] = self.generate_id_token(grant) + + return dict(token) + + def buggy_endpoint( + self, + *, + jwks: bool = False, + metadata: bool = False, + token: bool = False, + userinfo: bool = False, + ): + """A context which makes a set of endpoints return a 500 error. + + Args: + jwks: If True, makes the JWKS endpoint return a 500 error. + metadata: If True, makes the OIDC Discovery endpoint return a 500 error. + token: If True, makes the token endpoint return a 500 error. + userinfo: If True, makes the userinfo endpoint return a 500 error. + """ + buggy = FakeResponse(code=500, body=b"Internal server error") + + patches = {} + if jwks: + patches["get_jwks_handler"] = Mock(return_value=buggy) + if metadata: + patches["get_metadata_handler"] = Mock(return_value=buggy) + if token: + patches["post_token_handler"] = Mock(return_value=buggy) + if userinfo: + patches["get_userinfo_handler"] = Mock(return_value=buggy) + + return patch.multiple(self, **patches) + + async def _request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """The override of the SimpleHttpClient#request() method""" + access_token: Optional[str] = None + + if headers is None: + headers = Headers() + + # Try to find the access token in the headers if any + auth_headers = headers.getRawHeaders(b"Authorization") + if auth_headers: + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + access_token = parts[1].decode("ascii") + + if method == "POST": + # If the method is POST, assume it has an url-encoded body + if data is None or headers.getRawHeaders(b"Content-Type") != [ + b"application/x-www-form-urlencoded" + ]: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + params = parse_qs(data.decode("utf-8")) + + if uri == self.token_endpoint: + # Even though this endpoint should be protected, this does not check + # for client authentication. We're not checking it for simplicity, + # and because client authentication is tested in other standalone tests. + return self.post_token_handler(params) + + elif method == "GET": + if uri == self.jwks_uri: + return self.get_jwks_handler() + elif uri == self.metadata_endpoint: + return self.get_metadata_handler() + elif uri == self.userinfo_endpoint: + return self.get_userinfo_handler(access_token=access_token) + + return FakeResponse(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks_handler(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return FakeResponse.json(payload=self.get_jwks()) + + def _get_metadata_handler(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return FakeResponse.json(payload=self.get_metadata()) + + def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return FakeResponse(code=401) + user_info = self.get_userinfo(access_token) + if user_info is None: + return FakeResponse(code=401) + + return FakeResponse.json(payload=user_info) + + def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + grant = self.exchange_code(code=code[0]) + if grant is None: + return FakeResponse.json(code=400, payload={"error": "invalid_grant"}) + + return FakeResponse.json(payload=grant) -- cgit 1.4.1 From d902181de98399d90c46c4e4e2cf631064757941 Mon Sep 17 00:00:00 2001 From: James Salter Date: Tue, 25 Oct 2022 19:05:22 +0100 Subject: Unified search query syntax using the full-text search capabilities of the underlying DB. (#11635) Support a unified search query syntax which leverages more of the full-text search of each database supported by Synapse. Supports, with the same syntax across Postgresql 11+ and Sqlite: - quoted "search terms" - `AND`, `OR`, `-` (negation) operators - Matching words based on their stem, e.g. searches for "dog" matches documents containing "dogs". This is achieved by - If on postgresql 11+, pass the user input to `websearch_to_tsquery` - If on sqlite, manually parse the query and transform it into the sqlite-specific query syntax. Note that postgresql 10, which is close to end-of-life, falls back to using `phraseto_tsquery`, which only supports a subset of the features. Multiple terms separated by a space are implicitly ANDed. Note that: 1. There is no escaping of full-text syntax that might be supported by the database; e.g. `NOT`, `NEAR`, `*` in sqlite. This runs the risk that people might discover this as accidental functionality and depend on something we don't guarantee. 2. English text is assumed for stemming. To support other languages, either the target language needs to be known at the time of indexing the message (via room metadata, or otherwise), or a separate index for each language supported could be created. Sqlite docs: https://www.sqlite.org/fts3.html#full_text_index_queries Postgres docs: https://www.postgresql.org/docs/11/textsearch-controls.html --- changelog.d/11635.feature | 1 + synapse/storage/databases/main/search.py | 197 +++++++++++++++---- synapse/storage/engines/postgres.py | 16 ++ .../delta/73/10_update_sqlite_fts4_tokenizer.py | 62 ++++++ tests/storage/test_room_search.py | 213 +++++++++++++++++++++ 5 files changed, 454 insertions(+), 35 deletions(-) create mode 100644 changelog.d/11635.feature create mode 100644 synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py (limited to 'synapse') diff --git a/changelog.d/11635.feature b/changelog.d/11635.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/11635.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1b79acf955..a89fc54c2c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -11,10 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple +from collections import deque +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr @@ -27,7 +39,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -421,8 +433,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -444,20 +454,24 @@ class SearchStore(SearchBackgroundUpdateStore): count_clauses = clauses if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," " room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_for_sqlite(search_term) + sql = ( "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" " FROM event_search" @@ -469,7 +483,7 @@ class SearchStore(SearchBackgroundUpdateStore): "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ?" ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -501,7 +515,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -510,7 +526,6 @@ class SearchStore(SearchBackgroundUpdateStore): ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - return { "results": [ {"event": event_map[r["event_id"]], "rank": r["rank"]} @@ -542,9 +557,6 @@ class SearchStore(SearchBackgroundUpdateStore): Each match as a dictionary. """ clauses = [] - - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -582,20 +594,23 @@ class SearchStore(SearchBackgroundUpdateStore): args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," " origin_server_ts, stream_ordering, room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # @@ -614,13 +629,14 @@ class SearchStore(SearchBackgroundUpdateStore): " CROSS JOIN events USING (event_id)" " WHERE " ) + search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ? AND " ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -660,7 +676,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -686,7 +704,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase] + self, search_query: str, events: List[EventBase], tsquery_func: str ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -697,6 +715,7 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events + tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -729,7 +748,7 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( + query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( _to_postgres_options( { "StartSel": start_sel, @@ -760,20 +779,128 @@ def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. +@dataclass +class Phrase: + phrase: List[str] + + +class SearchToken(enum.Enum): + Not = enum.auto() + Or = enum.auto() + And = enum.auto() + + +Token = Union[str, Phrase, SearchToken] +TokenList = List[Token] + + +def _is_stop_word(word: str) -> bool: + # TODO Pull these out of the dictionary: + # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop + return word in {"the", "a", "you", "me", "and", "but"} + + +def _tokenize_query(query: str) -> TokenList: + """ + Convert the user-supplied `query` into a TokenList, which can be translated into + some DB-specific syntax. + + The following constructs are supported: + + - phrase queries using "double quotes" + - case-insensitive `or` and `and` operators + - negation of a keyword via unary `-` + - unary hyphen to denote NOT e.g. 'include -exclude' + + The following differs from websearch_to_tsquery: + + - Stop words are not removed. + - Unclosed phrases are treated differently. + + """ + tokens: TokenList = [] + + # Find phrases. + in_phrase = False + parts = deque(query.split('"')) + for i, part in enumerate(parts): + # The contents inside double quotes is treated as a phrase, a trailing + # double quote is not implied. + in_phrase = bool(i % 2) and i != (len(parts) - 1) + + # Pull out the individual words, discarding any non-word characters. + words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) + + # Phrases have simplified handling of words. + if in_phrase: + # Skip stop words. + phrase = [word for word in words if not _is_stop_word(word)] + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the phrase. + tokens.append(Phrase(phrase)) + continue + + # Otherwise, not in a phrase. + while words: + word = words.popleft() + + if word.startswith("-"): + tokens.append(SearchToken.Not) + + # If there's more word, put it back to be processed again. + word = word[1:] + if word: + words.appendleft(word) + elif word.lower() == "or": + tokens.append(SearchToken.Or) + else: + # Skip stop words. + if _is_stop_word(word): + continue + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the search term. + tokens.append(word) + + return tokens + + +def _tokens_to_sqlite_match_query(tokens: TokenList) -> str: + """ + Convert the list of tokens to a string suitable for passing to sqlite's MATCH. + Assume sqlite was compiled with enhanced query syntax. + + Ref: https://www.sqlite.org/fts3.html#full_text_index_queries """ + match_query = [] + for token in tokens: + if isinstance(token, str): + match_query.append(token) + elif isinstance(token, Phrase): + match_query.append('"' + " ".join(token.phrase) + '"') + elif token == SearchToken.Not: + # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search + # term has already been added before this. + match_query.append(" NOT ") + elif token == SearchToken.Or: + match_query.append(" OR ") + elif token == SearchToken.And: + match_query.append(" AND ") + else: + raise ValueError(f"unknown token {token}") + + return "".join(match_query) - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) - elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") +def _parse_query_for_sqlite(search_term: str) -> str: + """Takes a plain unicode string from the user and converts it into a form + that can be passed to sqllite's matchinfo(). + """ + return _tokens_to_sqlite_match_query(_tokenize_query(search_term)) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index d8c0f64d9a..9bf74bbf59 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -170,6 +170,22 @@ class PostgresEngine( """Do we support the `RETURNING` clause in insert/update/delete?""" return True + @property + def tsquery_func(self) -> str: + """ + Selects a tsquery_* func to use. + + Ref: https://www.postgresql.org/docs/current/textsearch-controls.html + + Returns: + The function name. + """ + # Postgres 11 added support for websearch_to_tsquery. + assert self._version is not None + if self._version >= 110000: + return "websearch_to_tsquery" + return "plainto_tsquery" + def is_deadlock(self, error: Exception) -> bool: if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py new file mode 100644 index 0000000000..3de0a709eb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py @@ -0,0 +1,62 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None: + """ + Upgrade the event_search table to use the porter tokenizer if it isn't already + + Applies only for sqlite. + """ + if not isinstance(database_engine, Sqlite3Engine): + return + + # Rebuild the table event_search table with tokenize=porter configured. + cur.execute("DROP TABLE event_search") + cur.execute( + """ + CREATE VIRTUAL TABLE event_search + USING fts4 (tokenize=porter, event_id, room_id, sender, key, value ) + """ + ) + + # Re-run the background job to re-populate the event_search table. + cur.execute("SELECT MIN(stream_ordering) FROM events") + row = cur.fetchone() + min_stream_id = row[0] + + # If there are not any events, nothing to do. + if min_stream_id is None: + return + + cur.execute("SELECT MAX(stream_ordering) FROM events") + row = cur.fetchone() + max_stream_id = row[0] + + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + } + progress_json = json.dumps(progress) + + sql = """ + INSERT into background_updates (ordering, update_name, progress_json) + VALUES (?, ?, ?) + """ + + cur.execute(sql, (7310, "event_search", progress_json)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index e747c6b50e..9ddc19900a 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple, Union +from unittest.case import SkipTest +from unittest.mock import PropertyMock, patch + +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.databases.main import DataStore +from synapse.storage.databases.main.search import Phrase, SearchToken, _tokenize_query from synapse.storage.engines import PostgresEngine +from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.util import Clock from tests.unittest import HomeserverTestCase, skip_unless from tests.utils import USE_POSTGRES_FOR_TESTS @@ -187,3 +198,205 @@ class EventSearchInsertionTest(HomeserverTestCase): ), ) self.assertCountEqual(values, ["hi", "2"]) + + +class MessageSearchTest(HomeserverTestCase): + """ + Check message search. + + A powerful way to check the behaviour is to run the following in Postgres >= 11: + + # SELECT websearch_to_tsquery('english', ); + + The result can be compared to the tokenized version for SQLite and Postgres < 11. + + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + PHRASE = "the quick brown fox jumps over the lazy dog" + + # Each entry is a search query, followed by either a boolean of whether it is + # in the phrase OR a tuple of booleans: whether it matches using websearch + # and using plain search. + COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + ("nope", False), + ("brown", True), + ("quick brown", True), + ("brown quick", True), + ("quick \t brown", True), + ("jump", True), + ("brown nope", False), + ('"brown quick"', (False, True)), + ('"jumps over"', True), + ('"quick fox"', (False, True)), + ("nope OR doublenope", False), + ("furphy OR fox", (True, False)), + ("fox -nope", (True, False)), + ("fox -brown", (False, True)), + ('"fox" quick', True), + ('"fox quick', True), + ('"quick brown', True), + ('" quick "', True), + ('" nope"', False), + ] + # TODO Test non-ASCII cases. + + # Case that fail on SQLite. + POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + # SQLite treats NOT as a binary operator. + ("- fox", (False, True)), + ("- nope", (True, False)), + ('"-fox quick', (False, True)), + # PostgreSQL skips stop words. + ('"the quick brown"', True), + ('"over lazy"', True), + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Register a user and create a room, create some messages + self.register_user("alice", "password") + self.access_token = self.login("alice", "password") + self.room_id = self.helper.create_room_as("alice", tok=self.access_token) + + # Send the phrase as a message and check it was created + response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) + self.assertIn("event_id", response) + + def test_tokenize_query(self) -> None: + """Test the custom logic to tokenize a user's query.""" + cases = ( + ("brown", ["brown"]), + ("quick brown", ["quick", SearchToken.And, "brown"]), + ("quick \t brown", ["quick", SearchToken.And, "brown"]), + ('"brown quick"', [Phrase(["brown", "quick"])]), + ("furphy OR fox", ["furphy", SearchToken.Or, "fox"]), + ("fox -brown", ["fox", SearchToken.Not, "brown"]), + ("- fox", [SearchToken.Not, "fox"]), + ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), + # No trailing double quoe. + ('"fox quick', ["fox", SearchToken.And, "quick"]), + ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + ('" quick "', [Phrase(["quick"])]), + ( + 'q"uick brow"n', + [ + "q", + SearchToken.And, + Phrase(["uick", "brow"]), + SearchToken.And, + "n", + ], + ), + ( + '-"quick brown"', + [SearchToken.Not, Phrase(["quick", "brown"])], + ), + ) + + for query, expected in cases: + tokenized = _tokenize_query(query) + self.assertEqual( + tokenized, expected, f"{tokenized} != {expected} for {query}" + ) + + def _check_test_cases( + self, + store: DataStore, + cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]], + index=0, + ) -> None: + # Run all the test cases versus search_msgs + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_msgs([self.room_id], query, ["content.body"]) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + # Run them again versus search_rooms + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_rooms([self.room_id], query, ["content.body"], 10) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + def test_postgres_web_search_for_phrase(self): + """ + Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. + This test is skipped unless the postgres instance supports websearch_to_tsquery. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + if store.database_engine.tsquery_func != "websearch_to_tsquery": + raise SkipTest( + "Test only applies when postgres supporting websearch_to_tsquery is used as the database" + ) + + self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0) + + def test_postgres_non_web_search_for_phrase(self): + """ + Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't + supported by the current postgres version. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path. + with patch( + "synapse.storage.engines.postgres.PostgresEngine.tsquery_func", + new_callable=PropertyMock, + ) as supports_websearch_to_tsquery: + supports_websearch_to_tsquery.return_value = "plainto_tsquery" + self._check_test_cases( + store, self.COMMON_CASES + self.POSTGRES_CASES, index=1 + ) + + def test_sqlite_search(self): + """ + Test sqlite searching for phrases. + """ + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, Sqlite3Engine): + raise SkipTest("Test only applies when sqlite is used as the database") + + self._check_test_cases(store, self.COMMON_CASES, index=0) -- cgit 1.4.1 From 8756d5c87efc5637da55c9e21d2a4eb2369ba693 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 26 Oct 2022 12:45:41 +0200 Subject: Save login tokens in database (#13844) * Save login tokens in database Signed-off-by: Quentin Gliech * Add upgrade notes * Track login token reuse in a Prometheus metric Signed-off-by: Quentin Gliech --- changelog.d/13844.misc | 1 + docs/upgrade.md | 9 ++ synapse/handlers/auth.py | 64 +++++++-- synapse/module_api/__init__.py | 41 +----- synapse/rest/client/login.py | 3 +- synapse/rest/client/login_token_request.py | 5 +- synapse/storage/databases/main/registration.py | 156 ++++++++++++++++++++- .../schema/main/delta/73/10login_tokens.sql | 35 +++++ synapse/util/macaroons.py | 87 +----------- tests/handlers/test_auth.py | 135 ++++++++++-------- tests/util/test_macaroons.py | 28 ---- 11 files changed, 337 insertions(+), 227 deletions(-) create mode 100644 changelog.d/13844.misc create mode 100644 synapse/storage/schema/main/delta/73/10login_tokens.sql (limited to 'synapse') diff --git a/changelog.d/13844.misc b/changelog.d/13844.misc new file mode 100644 index 0000000000..66f4414df7 --- /dev/null +++ b/changelog.d/13844.misc @@ -0,0 +1 @@ +Save login tokens in database and prevent login token reuse. diff --git a/docs/upgrade.md b/docs/upgrade.md index b81385b191..78c34d0c15 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.71.0 + +## Removal of the `generate_short_term_login_token` module API method + +As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed. + +Modules relying on it can instead use the `create_login_token` method. + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f5f0e0e7a7..8b9ef25d29 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -38,6 +38,7 @@ from typing import ( import attr import bcrypt import unpaddedbase64 +from prometheus_client import Counter from twisted.internet.defer import CancelledError from twisted.web.server import Request @@ -48,6 +49,7 @@ from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, LoginError, + NotFoundError, StoreError, SynapseError, UserDeactivatedError, @@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.registration import ( + LoginTokenExpired, + LoginTokenLookupResult, + LoginTokenReused, +) from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import delay_cancellation, maybe_awaitable -from synapse.util.macaroons import LoginTokenAttributes from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email @@ -80,6 +86,12 @@ logger = logging.getLogger(__name__) INVALID_USERNAME_OR_PASSWORD = "Invalid username or password" +invalid_login_token_counter = Counter( + "synapse_user_login_invalid_login_tokens", + "Counts the number of rejected m.login.token on /login", + ["reason"], +) + def convert_client_dict_legacy_fields_to_identifier( submission: JsonDict, @@ -883,6 +895,25 @@ class AuthHandler: return True + async def create_login_token_for_user_id( + self, + user_id: str, + duration_ms: int = (2 * 60 * 1000), + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> str: + login_token = self.generate_login_token() + now = self._clock.time_msec() + expiry_ts = now + duration_ms + await self.store.add_login_token_to_user( + user_id=user_id, + token=login_token, + expiry_ts=expiry_ts, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + return login_token + async def create_refresh_token_for_user_id( self, user_id: str, @@ -1401,6 +1432,18 @@ class AuthHandler: return None return user_id + def generate_login_token(self) -> str: + """Generates an opaque string, for use as an short-term login token""" + + # we use the following format for access tokens: + # syl__ + + random_string = stringutils.random_string(20) + base = f"syl_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + def generate_access_token(self, for_user: UserID) -> str: """Generates an opaque string, for use as an access token""" @@ -1427,16 +1470,17 @@ class AuthHandler: crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" - async def validate_short_term_login_token( - self, login_token: str - ) -> LoginTokenAttributes: + async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult: try: - res = self.macaroon_gen.verify_short_term_login_token(login_token) - except Exception: - raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) + return await self.store.consume_login_token(login_token) + except LoginTokenExpired: + invalid_login_token_counter.labels("expired").inc() + except LoginTokenReused: + invalid_login_token_counter.labels("reused").inc() + except NotFoundError: + invalid_login_token_counter.labels("not found").inc() - await self.auth_blocking.check_auth_blocking(res.user_id) - return res + raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) async def delete_access_token(self, access_token: str) -> None: """Invalidate a single access token @@ -1711,7 +1755,7 @@ class AuthHandler: ) # Create a login token - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.create_login_token_for_user_id( registered_user_id, auth_provider_id=auth_provider_id, auth_provider_session_id=auth_provider_session_id, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6a6ae208d1..30e689d00d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -771,50 +771,11 @@ class ModuleApi: auth_provider_session_id: The session ID got during login from the SSO IdP, if any. """ - # The deprecated `generate_short_term_login_token` method defaulted to an empty - # string for the `auth_provider_id` because of how the underlying macaroon was - # generated. This will change to a proper NULL-able field when the tokens get - # moved to the database. - return self._hs.get_macaroon_generator().generate_short_term_login_token( + return await self._hs.get_auth_handler().create_login_token_for_user_id( user_id, - auth_provider_id or "", - auth_provider_session_id, duration_in_ms, - ) - - def generate_short_term_login_token( - self, - user_id: str, - duration_in_ms: int = (2 * 60 * 1000), - auth_provider_id: str = "", - auth_provider_session_id: Optional[str] = None, - ) -> str: - """Generate a login token suitable for m.login.token authentication - - Added in Synapse v1.9.0. - - This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will - be removed in Synapse 1.71.0. - - Args: - user_id: gives the ID of the user that the token is for - - duration_in_ms: the time that the token will be valid for - - auth_provider_id: the ID of the SSO IdP that the user used to authenticate - to get this token, if any. This is encoded in the token so that - /login can report stats on number of successful logins by IdP. - """ - logger.warn( - "A module configured on this server uses ModuleApi.generate_short_term_login_token(), " - "which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in " - "Synapse 1.71.0", - ) - return self._hs.get_macaroon_generator().generate_short_term_login_token( - user_id, auth_provider_id, auth_provider_session_id, - duration_in_ms, ) @defer.inlineCallbacks diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f554586ac3..7774f1967d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet): The body of the JSON response. """ token = login_submission["token"] - auth_handler = self.auth_handler - res = await auth_handler.validate_short_term_login_token(token) + res = await self.auth_handler.consume_login_token(token) return await self._complete_login( res.user_id, diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index 277b20fb63..43ea21d5e6 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet): self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name - self.macaroon_gen = hs.get_macaroon_generator() self.auth_handler = hs.get_auth_handler() self.token_timeout = hs.config.experimental.msc3882_token_timeout self.ui_auth = hs.config.experimental.msc3882_ui_auth @@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet): can_skip_ui_auth=False, # Don't allow skipping of UI auth ) - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.auth_handler.create_login_token_for_user_id( user_id=requester.user.to_string(), auth_provider_id="org.matrix.msc3882.login_token_request", - duration_in_ms=self.token_timeout, + duration_ms=self.token_timeout, ) return ( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 2996d6bb4d..0255295317 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + NotFoundError, + StoreError, + SynapseError, + ThreepidValidationError, +) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception): because this external id is given to an other user.""" +class LoginTokenExpired(Exception): + """Exception if the login token sent expired""" + + +class LoginTokenReused(Exception): + """Exception if the login token sent was already used""" + + @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: """Result of looking up an access token. @@ -115,6 +129,20 @@ class RefreshTokenLookupResult: If None, the session can be refreshed indefinitely.""" +@attr.s(auto_attribs=True, frozen=True, slots=True) +class LoginTokenLookupResult: + """Result of looking up a login token.""" + + user_id: str + """The user this token belongs to.""" + + auth_provider_id: Optional[str] + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id: Optional[str] + """The session ID advertised by the SSO Identity Provider.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + async def add_login_token_to_user( + self, + user_id: str, + token: str, + expiry_ts: int, + auth_provider_id: Optional[str], + auth_provider_session_id: Optional[str], + ) -> None: + """Adds a short-term login token for the given user. + + Args: + user_id: The user ID. + token: The new login token to add. + expiry_ts (milliseconds since the epoch): Time after which the login token + cannot be used. + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token, if any + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider, if any. + """ + await self.db_pool.simple_insert( + "login_tokens", + { + "token": token, + "user_id": user_id, + "expiry_ts": expiry_ts, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="add_login_token_to_user", + ) + + def _consume_login_token( + self, + txn: LoggingTransaction, + token: str, + ts: int, + ) -> LoginTokenLookupResult: + values = self.db_pool.simple_select_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + retcols=( + "user_id", + "expiry_ts", + "used_ts", + "auth_provider_id", + "auth_provider_session_id", + ), + allow_none=True, + ) + + if values is None: + raise NotFoundError() + + self.db_pool.simple_update_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + updatevalues={"used_ts": ts}, + ) + user_id = values["user_id"] + expiry_ts = values["expiry_ts"] + used_ts = values["used_ts"] + auth_provider_id = values["auth_provider_id"] + auth_provider_session_id = values["auth_provider_session_id"] + + # Token was already used + if used_ts is not None: + raise LoginTokenReused() + + # Token expired + if ts > int(expiry_ts): + raise LoginTokenExpired() + + return LoginTokenLookupResult( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + async def consume_login_token(self, token: str) -> LoginTokenLookupResult: + """Lookup a login token and consume it. + + Args: + token: The login token. + + Returns: + The data stored with that token, including the `user_id`. Returns `None` if + the token does not exist or if it expired. + + Raises: + NotFound if the login token was not found in database + LoginTokenExpired if the login token expired + LoginTokenReused if the login token was already used + """ + return await self.db_pool.runInteraction( + "consume_login_token", + self._consume_login_token, + token, + self._clock.time_msec(), + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( @@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): and hs.config.experimental.msc3866.require_approval_for_new_accounts ) + # Create a background job for removing expired login tokens + if hs.config.worker.run_background_tasks: + self._clock.looping_call( + self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS + ) + async def add_access_token_to_user( self, user_id: str, @@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): approved, ) + @wrap_as_background_process("delete_expired_login_tokens") + async def _delete_expired_login_tokens(self) -> None: + """Remove login tokens with expiry dates that have passed.""" + + def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?" + txn.execute(sql, (ts,)) + + # We keep the expired tokens for an extra 5 minutes so we can measure how many + # times a token is being used after its expiry + now = self._clock.time_msec() + await self.db_pool.runInteraction( + "delete_expired_login_tokens", + _delete_expired_login_tokens_txn, + now - (5 * 60 * 1000), + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/schema/main/delta/73/10login_tokens.sql b/synapse/storage/schema/main/delta/73/10login_tokens.sql new file mode 100644 index 0000000000..a39b7bcece --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10login_tokens.sql @@ -0,0 +1,35 @@ +/* + * Copyright 2022 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Login tokens are short-lived tokens that are used for the m.login.token +-- login method, mainly during SSO logins +CREATE TABLE login_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + expiry_ts BIGINT NOT NULL, + used_ts BIGINT, + auth_provider_id TEXT, + auth_provider_session_id TEXT +); + +-- We're sometimes querying them by their session ID we got from their IDP +CREATE INDEX login_tokens_auth_provider_idx + ON login_tokens (auth_provider_id, auth_provider_session_id); + +-- We're deleting them by their expiration time +CREATE INDEX login_tokens_expiry_time_idx + ON login_tokens (expiry_ts); + diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index df77edcce2..5df03d3ddc 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -24,7 +24,7 @@ from typing_extensions import Literal from synapse.util import Clock, stringutils -MacaroonType = Literal["access", "delete_pusher", "session", "login"] +MacaroonType = Literal["access", "delete_pusher", "session"] def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: @@ -111,19 +111,6 @@ class OidcSessionData: """The session ID of the ongoing UI Auth ("" if this is a login)""" -@attr.s(slots=True, frozen=True, auto_attribs=True) -class LoginTokenAttributes: - """Data we store in a short-term login token""" - - user_id: str - - auth_provider_id: str - """The SSO Identity Provider that the user authenticated with, to get this token.""" - - auth_provider_session_id: Optional[str] - """The session ID advertised by the SSO Identity Provider.""" - - class MacaroonGenerator: def __init__(self, clock: Clock, location: str, secret_key: bytes): self._clock = clock @@ -165,35 +152,6 @@ class MacaroonGenerator: macaroon.add_first_party_caveat(f"pushkey = {pushkey}") return macaroon.serialize() - def generate_short_term_login_token( - self, - user_id: str, - auth_provider_id: str, - auth_provider_session_id: Optional[str] = None, - duration_in_ms: int = (2 * 60 * 1000), - ) -> str: - """Generate a short-term login token used during SSO logins - - Args: - user_id: The user for which the token is valid. - auth_provider_id: The SSO IdP the user used. - auth_provider_session_id: The session ID got during login from the SSO IdP. - - Returns: - A signed token valid for using as a ``m.login.token`` token. - """ - now = self._clock.time_msec() - expiry = now + duration_in_ms - macaroon = self._generate_base_macaroon("login") - macaroon.add_first_party_caveat(f"user_id = {user_id}") - macaroon.add_first_party_caveat(f"time < {expiry}") - macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}") - if auth_provider_session_id is not None: - macaroon.add_first_party_caveat( - f"auth_provider_session_id = {auth_provider_session_id}" - ) - return macaroon.serialize() - def generate_oidc_session_token( self, state: str, @@ -233,49 +191,6 @@ class MacaroonGenerator: return macaroon.serialize() - def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: - """Verify a short-term-login macaroon - - Checks that the given token is a valid, unexpired short-term-login token - minted by this server. - - Args: - token: The login token to verify. - - Returns: - A set of attributes carried by this token, including the - ``user_id`` and informations about the SSO IDP used during that - login. - - Raises: - MacaroonVerificationFailedException if the verification failed - """ - macaroon = pymacaroons.Macaroon.deserialize(token) - - v = self._base_verifier("login") - v.satisfy_general(lambda c: c.startswith("user_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) - satisfy_expiry(v, self._clock.time_msec) - v.verify(macaroon, self._secret_key) - - user_id = get_value_from_macaroon(macaroon, "user_id") - auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") - - auth_provider_session_id: Optional[str] = None - try: - auth_provider_session_id = get_value_from_macaroon( - macaroon, "auth_provider_session_id" - ) - except MacaroonVerificationFailedException: - pass - - return LoginTokenAttributes( - user_id=user_id, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) - def verify_guest_token(self, token: str) -> str: """Verify a guest access token macaroon diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 7106799d44..036dbbc45b 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from unittest.mock import Mock import pymacaroons @@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin +from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, + login.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1 = self.register_user("a_user", "pass") + def token_login(self, token: str) -> Optional[str]: + body = { + "type": "m.login.token", + "token": token, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + + if channel.code == 200: + return channel.json_body["user_id"] + + return None + def test_macaroon_caveats(self) -> None: token = self.macaroon_generator.generate_guest_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_guest) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - def test_short_term_login_token_gives_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + def test_login_token_gives_user_id(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) + + res = self.get_success(self.auth_handler.consume_login_token(token)) self.assertEqual(self.user1, res.user_id) - self.assertEqual("", res.auth_provider_id) + self.assertEqual(None, res.auth_provider_id) - # when we advance the clock, the token should be rejected - self.reactor.advance(6) - self.get_failure( - self.auth_handler.validate_short_term_login_token(token), - AuthError, + def test_login_token_reuse_fails(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - def test_short_term_login_token_gives_auth_provider(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, auth_provider_id="my_idp" - ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) - self.assertEqual(self.user1, res.user_id) - self.assertEqual("my_idp", res.auth_provider_id) + self.get_success(self.auth_handler.consume_login_token(token)) - def test_short_term_login_token_cannot_replace_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.get_failure( + self.auth_handler.consume_login_token(token), + AuthError, ) - macaroon = pymacaroons.Macaroon.deserialize(token) - res = self.get_success( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()) + def test_login_token_expires(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - self.assertEqual(self.user1, res.user_id) - - # add another "user_id" caveat, which might allow us to override the - # user_id. - macaroon.add_first_party_caveat("user_id = b_user") + # when we advance the clock, the token should be rejected + self.reactor.advance(6) self.get_failure( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()), + self.auth_handler.consume_login_token(token), AuthError, ) + def test_login_token_gives_auth_provider(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + auth_provider_id="my_idp", + auth_provider_session_id="11-22-33-44", + duration_ms=(5 * 1000), + ) + ) + res = self.get_success(self.auth_handler.consume_login_token(token)) + self.assertEqual(self.user1, res.user_id) + self.assertEqual("my_idp", res.auth_provider_id) + self.assertEqual("11-22-33-44", res.auth_provider_session_id) + def test_mau_limits_disabled(self) -> None: self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception @@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) + def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( @@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) def test_mau_limits_parity(self) -> None: # Ensure we're not at the unix epoch. @@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase): ), ResourceLimitError, ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) # If in monthly active cohort self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( @@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1, device_id=None, valid_until_ms=None ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True @@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) - ) - - def _get_macaroon(self) -> pymacaroons.Macaroon: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) - return pymacaroons.Macaroon.deserialize(token) + self.assertIsNotNone(self.token_login(token)) diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index 32125f7bb7..40754a4711 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase): ) self.assertEqual(user_id, "@user:tesths") - def test_short_term_login_token(self): - """Test the generation and verification of short-term login tokens""" - token = self.macaroon_generator.generate_short_term_login_token( - user_id="@user:tesths", - auth_provider_id="oidc", - auth_provider_session_id="sid", - duration_in_ms=2 * 60 * 1000, - ) - - info = self.macaroon_generator.verify_short_term_login_token(token) - self.assertEqual(info.user_id, "@user:tesths") - self.assertEqual(info.auth_provider_id, "oidc") - self.assertEqual(info.auth_provider_session_id, "sid") - - # Raises with another secret key - with self.assertRaises(MacaroonVerificationFailedException): - self.other_macaroon_generator.verify_short_term_login_token(token) - - # Wait a minute - self.reactor.pump([60]) - # Shouldn't raise - self.macaroon_generator.verify_short_term_login_token(token) - # Wait another minute - self.reactor.pump([60]) - # Should raise since it expired - with self.assertRaises(MacaroonVerificationFailedException): - self.macaroon_generator.verify_short_term_login_token(token) - def test_oidc_session_token(self): """Test the generation and verification of OIDC session cookies""" state = "arandomstate" -- cgit 1.4.1 From 04fd6221de026a74e8a3e896796d39dcf5ac6e3b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 26 Oct 2022 14:00:01 +0100 Subject: Fix incorrectly sending authentication tokens to application service as headers (#14301) --- changelog.d/14301.bugfix | 1 + synapse/appservice/api.py | 12 +++++++----- tests/appservice/test_api.py | 8 +++++--- 3 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14301.bugfix (limited to 'synapse') diff --git a/changelog.d/14301.bugfix b/changelog.d/14301.bugfix new file mode 100644 index 0000000000..668c1f3b9c --- /dev/null +++ b/changelog.d/14301.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where access tokens would be incorrectly sent to application services as headers. Application services which were obtaining access tokens from query parameters were not affected. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index fbac4375b0..60774b240d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -123,7 +123,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if response is not None: # just an empty json object return True @@ -147,7 +147,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if response is not None: # just an empty json object return True @@ -190,7 +190,9 @@ class ApplicationServiceApi(SimpleHttpClient): b"access_token": service.hs_token, } response = await self.get_json( - uri, args=args, headers={"Authorization": f"Bearer {service.hs_token}"} + uri, + args=args, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if not isinstance(response, list): logger.warning( @@ -230,7 +232,7 @@ class ApplicationServiceApi(SimpleHttpClient): info = await self.get_json( uri, {"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if not _is_valid_3pe_metadata(info): @@ -327,7 +329,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri=uri, json_body=body, args={"access_token": service.hs_token}, - headers={"Authorization": f"Bearer {service.hs_token}"}, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) if logger.isEnabledFor(logging.DEBUG): logger.debug( diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 11008ac1fb..89ee79396f 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Mapping +from typing import Any, List, Mapping, Sequence, Union from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -70,13 +70,15 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): self.request_url = None async def get_json( - url: str, args: Mapping[Any, Any], headers: Mapping[Any, Any] + url: str, + args: Mapping[Any, Any], + headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], ) -> List[JsonDict]: # Ensure the access token is passed as both a header and query arg. if not headers.get("Authorization") or not args.get(b"access_token"): raise RuntimeError("Access token not provided") - self.assertEqual(headers.get("Authorization"), f"Bearer {TOKEN}") + self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) self.assertEqual(args.get(b"access_token"), TOKEN) self.request_url = url if url == URL_USER: -- cgit 1.4.1 From 0cfbb3513152b8360155c2d75df50e06ea861fa4 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Wed, 26 Oct 2022 18:51:23 +0400 Subject: fix broken avatar checks when server_name contains a port (#13927) Fixes check_avatar_size_and_mime_type() to successfully update avatars on homeservers running on non-default ports which it would mistakenly treat as remote homeserver while validating the avatar's size and mime type. Signed-off-by: Ashish Kumar ashfame@users.noreply.github.com --- changelog.d/13927.bugfix | 1 + synapse/handlers/profile.py | 6 +++++- tests/handlers/test_profile.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13927.bugfix (limited to 'synapse') diff --git a/changelog.d/13927.bugfix b/changelog.d/13927.bugfix new file mode 100644 index 0000000000..119cd128e7 --- /dev/null +++ b/changelog.d/13927.bugfix @@ -0,0 +1 @@ +Fix a bug which prevented setting an avatar on homeservers which have an explicit port in their `server_name` and have `max_avatar_size` and/or `allowed_avatar_mimetypes` configuration. Contributed by @ashfame. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index d8ff5289b5..4bf9a047a3 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -307,7 +307,11 @@ class ProfileHandler: if not self.max_avatar_size and not self.allowed_avatar_mimetypes: return True - server_name, _, media_id = parse_and_validate_mxc_uri(mxc) + host, port, media_id = parse_and_validate_mxc_uri(mxc) + if port is not None: + server_name = host + ":" + str(port) + else: + server_name = host if server_name == self.server_name: media_info = await self.store.get_local_media(media_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f88c725a42..675aa023ac 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,6 +14,8 @@ from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor import synapse.types @@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertFalse(res) + @unittest.override_config( + {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} + ) + def test_avatar_constraint_on_local_server_with_port(self): + """Test that avatar metadata is correctly fetched when the media is on a local + server and the server has an explicit port. + + (This was previously a bug) + """ + local_server_name = self.hs.config.server.server_name + media_id = "local" + local_mxc = f"mxc://{local_server_name}/{media_id}" + + # mock up the existence of the avatar file + self._setup_local_files({media_id: {"mimetype": "image/png"}}) + + # and now check that check_avatar_size_and_mime_type is happy + self.assertTrue( + self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc)) + ) + + @parameterized.expand([("remote",), ("remote:1234",)]) + @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) + def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None: + """Test that avatar metadata is correctly fetched from a remote server""" + media_id = "remote" + remote_mxc = f"mxc://{remote_server_name}/{media_id}" + + # if the media is remote, check_avatar_size_and_mime_type just checks the + # media cache, so we don't need to instantiate a real remote server. It is + # sufficient to poke an entry into the db. + self.get_success( + self.hs.get_datastores().main.store_cached_remote_media( + media_id=media_id, + media_type="image/png", + media_length=50, + origin=remote_server_name, + time_now_ms=self.clock.time_msec(), + upload_name=None, + filesystem_id="xyz", + ) + ) + + self.assertTrue( + self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc)) + ) + def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): """Stores metadata about files in the database. -- cgit 1.4.1 From 40fa8294e3096132819287dd0c6d6bd71a408902 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 26 Oct 2022 16:10:55 -0500 Subject: Refactor MSC3030 `/timestamp_to_event` to move away from our snowflake pull from `destination` pattern (#14096) 1. `federation_client.timestamp_to_event(...)` now handles all `destination` looping and uses our generic `_try_destination_list(...)` helper. 2. Consistently handling `NotRetryingDestination` and `FederationDeniedError` across `get_pdu` , backfill, and the generic `_try_destination_list` which is used for many places we use this pattern. 3. `get_pdu(...)` now returns `PulledPduInfo` so we know which `destination` we ended up pulling the PDU from --- changelog.d/14096.misc | 1 + synapse/federation/federation_client.py | 130 ++++++++++++++++++++++++----- synapse/handlers/federation.py | 15 ++-- synapse/handlers/federation_event.py | 31 ++++--- synapse/handlers/room.py | 126 +++++++++++----------------- synapse/util/retryutils.py | 2 +- tests/federation/test_federation_client.py | 12 ++- 7 files changed, 191 insertions(+), 126 deletions(-) create mode 100644 changelog.d/14096.misc (limited to 'synapse') diff --git a/changelog.d/14096.misc b/changelog.d/14096.misc new file mode 100644 index 0000000000..2c07dc673b --- /dev/null +++ b/changelog.d/14096.misc @@ -0,0 +1 @@ +Refactor [MSC3030](https://github.com/matrix-org/matrix-spec-proposals/pull/3030) `/timestamp_to_event` endpoint to loop over federation destinations with standard pattern and error handling. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b220ab43fc..fa225182be 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -80,6 +80,18 @@ PDU_RETRY_TIME_MS = 1 * 60 * 1000 T = TypeVar("T") +@attr.s(frozen=True, slots=True, auto_attribs=True) +class PulledPduInfo: + """ + A result object that stores the PDU and info about it like which homeserver we + pulled it from (`pull_origin`) + """ + + pdu: EventBase + # Which homeserver we pulled the PDU from + pull_origin: str + + class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response we couldn't parse @@ -114,7 +126,9 @@ class FederationClient(FederationBase): self.hostname = hs.hostname self.signing_key = hs.signing_key - self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache( + # Cache mapping `event_id` to a tuple of the event itself and the `pull_origin` + # (which server we pulled the event from) + self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache( cache_name="get_pdu_cache", clock=self._clock, max_len=1000, @@ -352,11 +366,11 @@ class FederationClient(FederationBase): @tag_args async def get_pdu( self, - destinations: Iterable[str], + destinations: Collection[str], event_id: str, room_version: RoomVersion, timeout: Optional[int] = None, - ) -> Optional[EventBase]: + ) -> Optional[PulledPduInfo]: """Requests the PDU with given origin and ID from the remote home servers. @@ -371,11 +385,11 @@ class FederationClient(FederationBase): moving to the next destination. None indicates no timeout. Returns: - The requested PDU, or None if we were unable to find it. + The requested PDU wrapped in `PulledPduInfo`, or None if we were unable to find it. """ logger.debug( - "get_pdu: event_id=%s from destinations=%s", event_id, destinations + "get_pdu(event_id=%s): from destinations=%s", event_id, destinations ) # TODO: Rate limit the number of times we try and get the same event. @@ -384,19 +398,25 @@ class FederationClient(FederationBase): # it gets persisted to the database), so we cache the results of the lookup. # Note that this is separate to the regular get_event cache which caches # events once they have been persisted. - event = self._get_pdu_cache.get(event_id) + get_pdu_cache_entry = self._get_pdu_cache.get(event_id) + event = None + pull_origin = None + if get_pdu_cache_entry: + event, pull_origin = get_pdu_cache_entry # If we don't see the event in the cache, go try to fetch it from the # provided remote federated destinations - if not event: + else: pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + # TODO: We can probably refactor this to use `_try_destination_list` for destination in destinations: now = self._clock.time_msec() last_attempt = pdu_attempts.get(destination, 0) if last_attempt + PDU_RETRY_TIME_MS > now: logger.debug( - "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)", + "get_pdu(event_id=%s): skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)", + event_id, destination, last_attempt, PDU_RETRY_TIME_MS, @@ -411,43 +431,48 @@ class FederationClient(FederationBase): room_version=room_version, timeout=timeout, ) + pull_origin = destination pdu_attempts[destination] = now if event: # Prime the cache - self._get_pdu_cache[event.event_id] = event + self._get_pdu_cache[event.event_id] = (event, pull_origin) # Now that we have an event, we can break out of this # loop and stop asking other destinations. break + except NotRetryingDestination as e: + logger.info("get_pdu(event_id=%s): %s", event_id, e) + continue + except FederationDeniedError: + logger.info( + "get_pdu(event_id=%s): Not attempting to fetch PDU from %s because the homeserver is not on our federation whitelist", + event_id, + destination, + ) + continue except SynapseError as e: logger.info( - "Failed to get PDU %s from %s because %s", + "get_pdu(event_id=%s): Failed to get PDU from %s because %s", event_id, destination, e, ) continue - except NotRetryingDestination as e: - logger.info(str(e)) - continue - except FederationDeniedError as e: - logger.info(str(e)) - continue except Exception as e: pdu_attempts[destination] = now logger.info( - "Failed to get PDU %s from %s because %s", + "get_pdu(event_id=): Failed to get PDU from %s because %s", event_id, destination, e, ) continue - if not event: + if not event or not pull_origin: return None # `event` now refers to an object stored in `get_pdu_cache`. Our @@ -459,7 +484,7 @@ class FederationClient(FederationBase): event.room_version, ) - return event_copy + return PulledPduInfo(event_copy, pull_origin) @trace @tag_args @@ -699,12 +724,14 @@ class FederationClient(FederationBase): pdu_origin = get_domain_from_id(pdu.sender) if not res and pdu_origin != origin: try: - res = await self.get_pdu( + pulled_pdu_info = await self.get_pdu( destinations=[pdu_origin], event_id=pdu.event_id, room_version=room_version, timeout=10000, ) + if pulled_pdu_info is not None: + res = pulled_pdu_info.pdu except SynapseError: pass @@ -806,6 +833,7 @@ class FederationClient(FederationBase): ) for destination in destinations: + # We don't want to ask our own server for information we don't have if destination == self.server_name: continue @@ -814,9 +842,21 @@ class FederationClient(FederationBase): except ( RequestSendFailed, InvalidResponseError, - NotRetryingDestination, ) as e: logger.warning("Failed to %s via %s: %s", description, destination, e) + # Skip to the next homeserver in the list to try. + continue + except NotRetryingDestination as e: + logger.info("%s: %s", description, e) + continue + except FederationDeniedError: + logger.info( + "%s: Not attempting to %s from %s because the homeserver is not on our federation whitelist", + description, + description, + destination, + ) + continue except UnsupportedRoomVersionError: raise except HttpResponseException as e: @@ -1609,6 +1649,54 @@ class FederationClient(FederationBase): return result async def timestamp_to_event( + self, *, destinations: List[str], room_id: str, timestamp: int, direction: str + ) -> Optional["TimestampToEventResponse"]: + """ + Calls each remote federating server from `destinations` asking for their closest + event to the given timestamp in the given direction until we get a response. + Also validates the response to always return the expected keys or raises an + error. + + Args: + destinations: The domains of homeservers to try fetching from + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + A parsed TimestampToEventResponse including the closest event_id + and origin_server_ts or None if no destination has a response. + """ + + async def _timestamp_to_event_from_destination( + destination: str, + ) -> TimestampToEventResponse: + return await self._timestamp_to_event_from_destination( + destination, room_id, timestamp, direction + ) + + try: + # Loop through each homeserver candidate until we get a succesful response + timestamp_to_event_response = await self._try_destination_list( + "timestamp_to_event", + destinations, + # TODO: The requested timestamp may lie in a part of the + # event graph that the remote server *also* didn't have, + # in which case they will have returned another event + # which may be nowhere near the requested timestamp. In + # the future, we may need to reconcile that gap and ask + # other homeservers, and/or extend `/timestamp_to_event` + # to return events on *both* sides of the timestamp to + # help reconcile the gap faster. + _timestamp_to_event_from_destination, + ) + return timestamp_to_event_response + except SynapseError: + return None + + async def _timestamp_to_event_from_destination( self, destination: str, room_id: str, timestamp: int, direction: str ) -> "TimestampToEventResponse": """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4fbc79a6cb..5fc3b8bc8c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -442,6 +442,15 @@ class FederationHandler: # appropriate stuff. # TODO: We can probably do something more intelligent here. return True + except NotRetryingDestination as e: + logger.info("_maybe_backfill_inner: %s", e) + continue + except FederationDeniedError: + logger.info( + "_maybe_backfill_inner: Not attempting to backfill from %s because the homeserver is not on our federation whitelist", + dom, + ) + continue except (SynapseError, InvalidResponseError) as e: logger.info("Failed to backfill from %s because %s", dom, e) continue @@ -477,15 +486,9 @@ class FederationHandler: logger.info("Failed to backfill from %s because %s", dom, e) continue - except NotRetryingDestination as e: - logger.info(str(e)) - continue except RequestSendFailed as e: logger.info("Failed to get backfill from %s because %s", dom, e) continue - except FederationDeniedError as e: - logger.info(e) - continue except Exception as e: logger.exception("Failed to backfill from %s because %s", dom, e) continue diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7da6316a82..9ca5df7c78 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -58,7 +58,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.federation.federation_client import InvalidResponseError +from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import ( SynapseTags, @@ -1517,8 +1517,8 @@ class FederationEventHandler: ) async def backfill_event_id( - self, destination: str, room_id: str, event_id: str - ) -> EventBase: + self, destinations: List[str], room_id: str, event_id: str + ) -> PulledPduInfo: """Backfill a single event and persist it as a non-outlier which means we also pull in all of the state and auth events necessary for it. @@ -1530,24 +1530,21 @@ class FederationEventHandler: Raises: FederationError if we are unable to find the event from the destination """ - logger.info( - "backfill_event_id: event_id=%s from destination=%s", event_id, destination - ) + logger.info("backfill_event_id: event_id=%s", event_id) room_version = await self._store.get_room_version(room_id) - event_from_response = await self._federation_client.get_pdu( - [destination], + pulled_pdu_info = await self._federation_client.get_pdu( + destinations, event_id, room_version, ) - if not event_from_response: + if not pulled_pdu_info: raise FederationError( "ERROR", 404, - "Unable to find event_id=%s from destination=%s to backfill." - % (event_id, destination), + f"Unable to find event_id={event_id} from remote servers to backfill.", affected=event_id, ) @@ -1555,13 +1552,13 @@ class FederationEventHandler: # and auth events to de-outlier it. This also sets up the necessary # `state_groups` for the event. await self._process_pulled_events( - destination, - [event_from_response], + pulled_pdu_info.pull_origin, + [pulled_pdu_info.pdu], # Prevent notifications going to clients backfilled=True, ) - return event_from_response + return pulled_pdu_info @trace @tag_args @@ -1584,19 +1581,19 @@ class FederationEventHandler: async def get_event(event_id: str) -> None: with nested_logging_context(event_id): try: - event = await self._federation_client.get_pdu( + pulled_pdu_info = await self._federation_client.get_pdu( [destination], event_id, room_version, ) - if event is None: + if pulled_pdu_info is None: logger.warning( "Server %s didn't return event %s", destination, event_id, ) return - events.append(event) + events.append(pulled_pdu_info.pdu) except Exception as e: logger.warning( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index cc1e5c8f97..de97886ea9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -49,7 +49,6 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -60,7 +59,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_and_fixup_power_levels_contents -from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin @@ -1472,7 +1470,12 @@ class TimestampLookupHandler: Raises: SynapseError if unable to find any event locally in the given direction """ - + logger.debug( + "get_event_for_timestamp(room_id=%s, timestamp=%s, direction=%s) Finding closest event...", + room_id, + timestamp, + direction, + ) local_event_id = await self.store.get_event_id_for_timestamp( room_id, timestamp, direction ) @@ -1524,85 +1527,54 @@ class TimestampLookupHandler: ) ) - # Loop through each homeserver candidate until we get a succesful response - for domain in likely_domains: - # We don't want to ask our own server for information we don't have - if domain == self.server_name: - continue + remote_response = await self.federation_client.timestamp_to_event( + destinations=likely_domains, + room_id=room_id, + timestamp=timestamp, + direction=direction, + ) + if remote_response is not None: + logger.debug( + "get_event_for_timestamp: remote_response=%s", + remote_response, + ) - try: - remote_response = await self.federation_client.timestamp_to_event( - domain, room_id, timestamp, direction - ) - logger.debug( - "get_event_for_timestamp: response from domain(%s)=%s", - domain, - remote_response, - ) + remote_event_id = remote_response.event_id + remote_origin_server_ts = remote_response.origin_server_ts - remote_event_id = remote_response.event_id - remote_origin_server_ts = remote_response.origin_server_ts - - # Backfill this event so we can get a pagination token for - # it with `/context` and paginate `/messages` from this - # point. - # - # TODO: The requested timestamp may lie in a part of the - # event graph that the remote server *also* didn't have, - # in which case they will have returned another event - # which may be nowhere near the requested timestamp. In - # the future, we may need to reconcile that gap and ask - # other homeservers, and/or extend `/timestamp_to_event` - # to return events on *both* sides of the timestamp to - # help reconcile the gap faster. - remote_event = ( - await self.federation_event_handler.backfill_event_id( - domain, room_id, remote_event_id - ) - ) + # Backfill this event so we can get a pagination token for + # it with `/context` and paginate `/messages` from this + # point. + pulled_pdu_info = await self.federation_event_handler.backfill_event_id( + likely_domains, room_id, remote_event_id + ) + remote_event = pulled_pdu_info.pdu - # XXX: When we see that the remote server is not trustworthy, - # maybe we should not ask them first in the future. - if remote_origin_server_ts != remote_event.origin_server_ts: - logger.info( - "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.", - domain, - remote_event_id, - remote_origin_server_ts, - remote_event.origin_server_ts, - ) - - # Only return the remote event if it's closer than the local event - if not local_event or ( - abs(remote_event.origin_server_ts - timestamp) - < abs(local_event.origin_server_ts - timestamp) - ): - logger.info( - "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)", - remote_event_id, - remote_event.origin_server_ts, - timestamp, - local_event.event_id if local_event else None, - local_event.origin_server_ts if local_event else None, - ) - return remote_event_id, remote_origin_server_ts - except (HttpResponseException, InvalidResponseError) as ex: - # Let's not put a high priority on some other homeserver - # failing to respond or giving a random response - logger.debug( - "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", - domain, - type(ex).__name__, - ex, - ex.args, + # XXX: When we see that the remote server is not trustworthy, + # maybe we should not ask them first in the future. + if remote_origin_server_ts != remote_event.origin_server_ts: + logger.info( + "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.", + pulled_pdu_info.pull_origin, + remote_event_id, + remote_origin_server_ts, + remote_event.origin_server_ts, ) - except Exception: - # But we do want to see some exceptions in our code - logger.warning( - "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception", - domain, - exc_info=True, + + # Only return the remote event if it's closer than the local event + if not local_event or ( + abs(remote_event.origin_server_ts - timestamp) + < abs(local_event.origin_server_ts - timestamp) + ): + logger.info( + "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)", + remote_event_id, + remote_event.origin_server_ts, + timestamp, + local_event.event_id if local_event else None, + local_event.origin_server_ts if local_event else None, ) + return remote_event_id, remote_origin_server_ts # To appease mypy, we have to add both of these conditions to check for # `None`. We only expect `local_event` to be `None` when diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index d0a69ff843..dcc037b982 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -51,7 +51,7 @@ class NotRetryingDestination(Exception): destination: the domain in question """ - msg = "Not retrying server %s." % (destination,) + msg = f"Not retrying server {destination} because we tried it recently retry_last_ts={retry_last_ts} and we won't check for another retry_interval={retry_interval}ms." super().__init__(msg) self.retry_last_ts = retry_last_ts diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 51d3bb8fff..e67f405826 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -142,14 +142,14 @@ class FederationClientTest(FederatingHomeserverTestCase): def test_get_pdu_returns_nothing_when_event_does_not_exist(self): """No event should be returned when the event does not exist""" - remote_pdu = self.get_success( + pulled_pdu_info = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], "event_should_not_exist", RoomVersions.V9, ) ) - self.assertEqual(remote_pdu, None) + self.assertEqual(pulled_pdu_info, None) def test_get_pdu(self): """Test to make sure an event is returned by `get_pdu()`""" @@ -169,13 +169,15 @@ class FederationClientTest(FederatingHomeserverTestCase): remote_pdu.internal_metadata.outlier = True # Get the event again. This time it should read it from cache. - remote_pdu2 = self.get_success( + pulled_pdu_info2 = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], remote_pdu.event_id, RoomVersions.V9, ) ) + self.assertIsNotNone(pulled_pdu_info2) + remote_pdu2 = pulled_pdu_info2.pdu # Sanity check that we are working against the same event self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id) @@ -215,13 +217,15 @@ class FederationClientTest(FederatingHomeserverTestCase): ) ) - remote_pdu = self.get_success( + pulled_pdu_info = self.get_success( self.hs.get_federation_client().get_pdu( ["yet.another.server"], "event_id", RoomVersions.V9, ) ) + self.assertIsNotNone(pulled_pdu_info) + remote_pdu = pulled_pdu_info.pdu # check the right call got made to the agent self._mock_agent.request.assert_called_once_with( -- cgit 1.4.1 From cbe01ccc3f9c09a0a7233f90200fbcb8ae5245cf Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 27 Oct 2022 10:52:23 +0100 Subject: Reject history insertion during partial joins (#14291) --- changelog.d/14291.bugfix | 1 + synapse/rest/client/room_batch.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/14291.bugfix (limited to 'synapse') diff --git a/changelog.d/14291.bugfix b/changelog.d/14291.bugfix new file mode 100644 index 0000000000..bac5065e94 --- /dev/null +++ b/changelog.d/14291.bugfix @@ -0,0 +1 @@ +Prevent history insertion ([MSC2716](https://github.com/matrix-org/matrix-spec-proposals/pull/2716)) during an partial join ([MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706)). diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index dd91dabedd..10be4a781b 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -108,6 +108,13 @@ class RoomBatchSendEventRestServlet(RestServlet): errcode=Codes.MISSING_PARAM, ) + if await self.store.is_partial_state_room(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Cannot insert history batches until we have fully joined the room", + errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, + ) + # Verify the batch_id_from_query corresponds to an actual insertion event # and have the batch connected. if batch_id_from_query: -- cgit 1.4.1 From 4dc05f30193935224103e8772b1bbc15293e5cb6 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 27 Oct 2022 14:16:00 +0200 Subject: Fix presence bug introduced in 1.64 by #13313 (#14243) * Fix presence bug introduced in 1.64 by #13313 Signed-off-by: Mathieu Velten * Add changelog * Add DISTINCT * Apply suggestions from code review Signed-off-by: Mathieu Velten --- changelog.d/14243.bugfix | 1 + synapse/storage/databases/main/roommember.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14243.bugfix (limited to 'synapse') diff --git a/changelog.d/14243.bugfix b/changelog.d/14243.bugfix new file mode 100644 index 0000000000..ac0b21c2c5 --- /dev/null +++ b/changelog.d/14243.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.64.0 where presence updates could be missing from `/sync` responses. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 32e1e983a5..ab708b0ba5 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -742,7 +742,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # user and the set of other users, and then checking if there is any # overlap. sql = f""" - SELECT b.state_key + SELECT DISTINCT b.state_key FROM ( SELECT room_id FROM current_state_events WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ? @@ -751,7 +751,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): SELECT room_id, state_key FROM current_state_events WHERE type = 'm.room.member' AND membership = 'join' AND {clause} ) AS b using (room_id) - LIMIT 1 """ txn.execute(sql, (user_id, *args)) -- cgit 1.4.1 From 1357ae869f279a3f0855c1b1c2750eca2887928e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 27 Oct 2022 15:39:47 +0200 Subject: Add workers settings to configuration manual (#14086) * Add workers settings to configuration manual * Update `pusher_instances` * update url to python logger * update headlines * update links after headline change * remove link from `daemon process` There is no docs in Synapse for this * extend example for `federation_sender_instances` and `pusher_instances` * more infos about stream writers * add link to DAG * update `pusher_instances` * update `worker_listeners` * update `stream_writers` * Update `worker_name` Co-authored-by: David Robertson --- changelog.d/14086.doc | 1 + docs/sample_log_config.yaml | 2 +- docs/usage/configuration/config_documentation.md | 268 +++++++++++++++++++---- docs/workers.md | 100 ++++++--- synapse/config/logger.py | 2 +- 5 files changed, 291 insertions(+), 82 deletions(-) create mode 100644 changelog.d/14086.doc (limited to 'synapse') diff --git a/changelog.d/14086.doc b/changelog.d/14086.doc new file mode 100644 index 0000000000..5b4b938759 --- /dev/null +++ b/changelog.d/14086.doc @@ -0,0 +1 @@ +Add workers settings to [configuration manual](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#individual-worker-configuration). \ No newline at end of file diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 3065a0e2d9..6339160d00 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -6,7 +6,7 @@ # Synapse also supports structured logging for machine readable logs which can # be ingested by ELK stacks. See [2] for details. # -# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [1]: https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema # [2]: https://matrix-org.github.io/synapse/latest/structured_logging.html version: 1 diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index d81eda52c1..fb5eb42c52 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -99,7 +99,7 @@ modules: config: {} ``` --- -## Server ## +## Server Define your homeserver name and other base options. @@ -159,7 +159,7 @@ including _matrix/...). This is the same URL a user might enter into the 'Custom Homeserver URL' field on their client. If you use Synapse with a reverse proxy, this should be the URL to reach Synapse via the proxy. Otherwise, it should be the URL to reach Synapse's client HTTP listener (see -'listeners' below). +['listeners'](#listeners) below). Defaults to `https:///`. @@ -570,7 +570,7 @@ Example configuration: delete_stale_devices_after: 1y ``` -## Homeserver blocking ## +## Homeserver blocking Useful options for Synapse admins. --- @@ -922,7 +922,7 @@ retention: interval: 1d ``` --- -## TLS ## +## TLS Options related to TLS. @@ -1012,7 +1012,7 @@ federation_custom_ca_list: - myCA3.pem ``` --- -## Federation ## +## Federation Options related to federation. @@ -1071,7 +1071,7 @@ Example configuration: allow_device_name_lookup_over_federation: true ``` --- -## Caching ## +## Caching Options related to caching. @@ -1185,7 +1185,7 @@ file in Synapse's `contrib` directory, you can send a `SIGHUP` signal by using `systemctl reload matrix-synapse`. --- -## Database ## +## Database Config options related to database settings. --- @@ -1332,20 +1332,21 @@ databases: cp_max: 10 ``` --- -## Logging ## +## Logging Config options related to logging. --- ### `log_config` -This option specifies a yaml python logging config file as described [here](https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema). +This option specifies a yaml python logging config file as described +[here](https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema). Example configuration: ```yaml log_config: "CONFDIR/SERVERNAME.log.config" ``` --- -## Ratelimiting ## +## Ratelimiting Options related to ratelimiting in Synapse. Each ratelimiting configuration is made of two parameters: @@ -1576,7 +1577,7 @@ Example configuration: federation_rr_transactions_per_room_per_second: 40 ``` --- -## Media Store ## +## Media Store Config options related to Synapse's media store. --- @@ -1766,7 +1767,7 @@ url_preview_ip_range_blacklist: - 'ff00::/8' - 'fec0::/10' ``` ----- +--- ### `url_preview_ip_range_whitelist` This option sets a list of IP address CIDR ranges that the URL preview spider is allowed @@ -1860,7 +1861,7 @@ Example configuration: - 'fr;q=0.8' - '*;q=0.7' ``` ----- +--- ### `oembed` oEmbed allows for easier embedding content from a website. It can be @@ -1877,7 +1878,7 @@ oembed: - oembed/my_providers.json ``` --- -## Captcha ## +## Captcha See [here](../../CAPTCHA_SETUP.md) for full details on setting up captcha. @@ -1926,7 +1927,7 @@ Example configuration: recaptcha_siteverify_api: "https://my.recaptcha.site" ``` --- -## TURN ## +## TURN Options related to adding a TURN server to Synapse. --- @@ -1947,7 +1948,7 @@ Example configuration: ```yaml turn_shared_secret: "YOUR_SHARED_SECRET" ``` ----- +--- ### `turn_username` and `turn_password` The Username and password if the TURN server needs them and does not use a token. @@ -2366,7 +2367,7 @@ Example configuration: ```yaml session_lifetime: 24h ``` ----- +--- ### `refresh_access_token_lifetime` Time that an access token remains valid for, if the session is using refresh tokens. @@ -2422,7 +2423,7 @@ nonrefreshable_access_token_lifetime: 24h ``` --- -## Metrics ### +## Metrics Config options related to metrics. --- @@ -2519,7 +2520,7 @@ Example configuration: report_stats_endpoint: https://example.com/report-usage-stats/push ``` --- -## API Configuration ## +## API Configuration Config settings related to the client/server API --- @@ -2619,7 +2620,7 @@ Example configuration: form_secret: ``` --- -## Signing Keys ## +## Signing Keys Config options relating to signing keys --- @@ -2728,7 +2729,7 @@ Example configuration: key_server_signing_keys_path: "key_server_signing_keys.key" ``` --- -## Single sign-on integration ## +## Single sign-on integration The following settings can be used to make Synapse use a single sign-on provider for authentication, instead of its internal password database. @@ -3348,7 +3349,7 @@ email: email_validation: "[%(server_name)s] Validate your email" ``` --- -## Push ## +## Push Configuration settings related to push notifications --- @@ -3381,7 +3382,7 @@ push: group_unread_count_by_room: false ``` --- -## Rooms ## +## Rooms Config options relating to rooms. --- @@ -3627,7 +3628,7 @@ default_power_level_content_override: ``` --- -## Opentracing ## +## Opentracing Configuration options related to Opentracing support. --- @@ -3670,14 +3671,71 @@ opentracing: false ``` --- -## Workers ## -Configuration options related to workers. +## Coordinating workers +Configuration options related to workers which belong in the main config file +(usually called `homeserver.yaml`). +A Synapse deployment can scale horizontally by running multiple Synapse processes +called _workers_. Incoming requests are distributed between workers to handle higher +loads. Some workers are privileged and can accept requests from other workers. + +As a result, the worker configuration is divided into two parts. + +1. The first part (in this section of the manual) defines which shardable tasks + are delegated to privileged workers. This allows unprivileged workers to make + request a privileged worker to act on their behalf. +1. [The second part](#individual-worker-configuration) + controls the behaviour of individual workers in isolation. + +For guidance on setting up workers, see the [worker documentation](../../workers.md). + +--- +### `worker_replication_secret` + +A shared secret used by the replication APIs on the main process to authenticate +HTTP requests from workers. + +The default, this value is omitted (equivalently `null`), which means that +traffic between the workers and the main process is not authenticated. + +Example configuration: +```yaml +worker_replication_secret: "secret_secret" +``` +--- +### `start_pushers` + +Controls sending of push notifications on the main process. Set to `false` +if using a [pusher worker](../../workers.md#synapseapppusher). Defaults to `true`. + +Example configuration: +```yaml +start_pushers: false +``` +--- +### `pusher_instances` + +It is possible to run multiple [pusher workers](../../workers.md#synapseapppusher), +in which case the work is balanced across them. Use this setting to list the pushers by +[`worker_name`](#worker_name). Ensure the main process and all pusher workers are +restarted after changing this option. +If no or only one pusher worker is configured, this setting is not necessary. +The main process will send out push notifications by default if you do not disable +it by setting [`start_pushers: false`](#start_pushers). + +Example configuration: +```yaml +start_pushers: false +pusher_instances: + - pusher_worker1 + - pusher_worker2 +``` --- ### `send_federation` Controls sending of outbound federation transactions on the main process. -Set to false if using a federation sender worker. Defaults to true. +Set to `false` if using a [federation sender worker](../../workers.md#synapseappfederation_sender). +Defaults to `true`. Example configuration: ```yaml @@ -3686,8 +3744,9 @@ send_federation: false --- ### `federation_sender_instances` -It is possible to run multiple federation sender workers, in which case the -work is balanced across them. Use this setting to list the senders. +It is possible to run multiple +[federation sender worker](../../workers.md#synapseappfederation_sender), in which +case the work is balanced across them. Use this setting to list the senders. This configuration setting must be shared between all federation sender workers, and if changed all federation sender workers must be stopped at the same time and then @@ -3696,14 +3755,19 @@ events may be dropped). Example configuration: ```yaml +send_federation: false federation_sender_instances: - federation_sender1 ``` --- ### `instance_map` -When using workers this should be a map from worker name to the +When using workers this should be a map from [`worker_name`](#worker_name) to the HTTP replication listener of the worker, if configured. +Each worker declared under [`stream_writers`](../../workers.md#stream-writers) needs +a HTTP replication listener, and that listener should be included in the `instance_map`. +(The main process also needs an HTTP replication listener, but it should not be +listed in the `instance_map`.) Example configuration: ```yaml @@ -3716,8 +3780,11 @@ instance_map: ### `stream_writers` Experimental: When using workers you can define which workers should -handle event persistence and typing notifications. Any worker -specified here must also be in the `instance_map`. +handle writing to streams such as event persistence and typing notifications. +Any worker specified here must also be in the [`instance_map`](#instance_map). + +See the list of available streams in the +[worker documentation](../../workers.md#stream-writers). Example configuration: ```yaml @@ -3728,29 +3795,18 @@ stream_writers: --- ### `run_background_tasks_on` -The worker that is used to run background tasks (e.g. cleaning up expired -data). If not provided this defaults to the main process. +The [worker](../../workers.md#background-tasks) that is used to run +background tasks (e.g. cleaning up expired data). If not provided this +defaults to the main process. Example configuration: ```yaml run_background_tasks_on: worker1 ``` --- -### `worker_replication_secret` - -A shared secret used by the replication APIs to authenticate HTTP requests -from workers. - -By default this is unused and traffic is not authenticated. - -Example configuration: -```yaml -worker_replication_secret: "secret_secret" -``` ### `redis` -Configuration for Redis when using workers. This *must* be enabled when -using workers (unless using old style direct TCP configuration). +Configuration for Redis when using workers. This *must* be enabled when using workers. This setting has the following sub-options: * `enabled`: whether to use Redis support. Defaults to false. * `host` and `port`: Optional host and port to use to connect to redis. Defaults to @@ -3765,7 +3821,123 @@ redis: port: 6379 password: ``` -## Background Updates ## +--- +## Individual worker configuration +These options configure an individual worker, in its worker configuration file. +They should be not be provided when configuring the main process. + +Note also the configuration above for +[coordinating a cluster of workers](#coordinating-workers). + +For guidance on setting up workers, see the [worker documentation](../../workers.md). + +--- +### `worker_app` + +The type of worker. The currently available worker applications are listed +in [worker documentation](../../workers.md#available-worker-applications). + +The most common worker is the +[`synapse.app.generic_worker`](../../workers.md#synapseappgeneric_worker). + +Example configuration: +```yaml +worker_app: synapse.app.generic_worker +``` +--- +### `worker_name` + +A unique name for the worker. The worker needs a name to be addressed in +further parameters and identification in log files. We strongly recommend +giving each worker a unique `worker_name`. + +Example configuration: +```yaml +worker_name: generic_worker1 +``` +--- +### `worker_replication_host` + +The HTTP replication endpoint that it should talk to on the main Synapse process. +The main Synapse process defines this with a `replication` resource in +[`listeners` option](#listeners). + +Example configuration: +```yaml +worker_replication_host: 127.0.0.1 +``` +--- +### `worker_replication_http_port` + +The HTTP replication port that it should talk to on the main Synapse process. +The main Synapse process defines this with a `replication` resource in +[`listeners` option](#listeners). + +Example configuration: +```yaml +worker_replication_http_port: 9093 +``` +--- +### `worker_listeners` + +A worker can handle HTTP requests. To do so, a `worker_listeners` option +must be declared, in the same way as the [`listeners` option](#listeners) +in the shared config. + +Workers declared in [`stream_writers`](#stream_writers) will need to include a +`replication` listener here, in order to accept internal HTTP requests from +other workers. + +Example configuration: +```yaml +worker_listeners: + - type: http + port: 8083 + resources: + - names: [client, federation] +``` +--- +### `worker_daemonize` + +Specifies whether the worker should be started as a daemon process. +If Synapse is being managed by [systemd](../../systemd-with-workers/README.md), this option +must be omitted or set to `false`. + +Defaults to `false`. + +Example configuration: +```yaml +worker_daemonize: true +``` +--- +### `worker_pid_file` + +When running a worker as a daemon, we need a place to store the +[PID](https://en.wikipedia.org/wiki/Process_identifier) of the worker. +This option defines the location of that "pid file". + +This option is required if `worker_daemonize` is `true` and ignored +otherwise. It has no default. + +See also the [`pid_file` option](#pid_file) option for the main Synapse process. + +Example configuration: +```yaml +worker_pid_file: DATADIR/generic_worker1.pid +``` +--- +### `worker_log_config` + +This option specifies a yaml python logging config file as described +[here](https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema). +See also the [`log_config` option](#log_config) option for the main Synapse process. + +Example configuration: +```yaml +worker_log_config: /etc/matrix-synapse/generic-worker-log.yaml +``` +--- +## Background Updates Configuration settings related to background updates. --- diff --git a/docs/workers.md b/docs/workers.md index c27b3f8bd5..5e1b9ba220 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -88,10 +88,12 @@ shared configuration file. ### Shared configuration Normally, only a couple of changes are needed to make an existing configuration -file suitable for use with workers. First, you need to enable an "HTTP replication -listener" for the main process; and secondly, you need to enable redis-based -replication. Optionally, a shared secret can be used to authenticate HTTP -traffic between workers. For example: +file suitable for use with workers. First, you need to enable an +["HTTP replication listener"](usage/configuration/config_documentation.md#listeners) +for the main process; and secondly, you need to enable +[redis-based replication](usage/configuration/config_documentation.md#redis). +Optionally, a [shared secret](usage/configuration/config_documentation.md#worker_replication_secret) +can be used to authenticate HTTP traffic between workers. For example: ```yaml # extend the existing `listeners` section. This defines the ports that the @@ -111,25 +113,28 @@ redis: enabled: true ``` -See the [configuration manual](usage/configuration/config_documentation.html) for the full documentation of each option. +See the [configuration manual](usage/configuration/config_documentation.md) +for the full documentation of each option. Under **no circumstances** should the replication listener be exposed to the public internet; replication traffic is: * always unencrypted -* unauthenticated, unless `worker_replication_secret` is configured +* unauthenticated, unless [`worker_replication_secret`](usage/configuration/config_documentation.md#worker_replication_secret) + is configured ### Worker configuration In the config file for each worker, you must specify: - * The type of worker (`worker_app`). The currently available worker applications are listed below. - * A unique name for the worker (`worker_name`). + * The type of worker ([`worker_app`](usage/configuration/config_documentation.md#worker_app)). + The currently available worker applications are listed [below](#available-worker-applications). + * A unique name for the worker ([`worker_name`](usage/configuration/config_documentation.md#worker_name)). * The HTTP replication endpoint that it should talk to on the main synapse process - (`worker_replication_host` and `worker_replication_http_port`) - * If handling HTTP requests, a `worker_listeners` option with an `http` - listener, in the same way as the [`listeners`](usage/configuration/config_documentation.md#listeners) - option in the shared config. + ([`worker_replication_host`](usage/configuration/config_documentation.md#worker_replication_host) and + [`worker_replication_http_port`](usage/configuration/config_documentation.md#worker_replication_http_port)). + * If handling HTTP requests, a [`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option + with an `http` listener. * If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for the main process (`worker_main_http_uri`). @@ -146,7 +151,6 @@ plain HTTP endpoint on port 8083 separately serving various endpoints, e.g. Obviously you should configure your reverse-proxy to route the relevant endpoints to the worker (`localhost:8083` in the above example). - ### Running Synapse with workers Finally, you need to start your worker processes. This can be done with either @@ -288,7 +292,8 @@ For multiple workers not handling the SSO endpoints properly, see [#9427](https://github.com/matrix-org/synapse/issues/9427). Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners) -with `client` and `federation` `resources` must be configured in the `worker_listeners` +with `client` and `federation` `resources` must be configured in the +[`worker_listeners`](usage/configuration/config_documentation.md#worker_listeners) option in the worker config. #### Load balancing @@ -331,9 +336,10 @@ of the main process to a particular worker. To enable this, the worker must have a [HTTP `replication` listener](usage/configuration/config_documentation.md#listeners) configured, -have a `worker_name` and be listed in the `instance_map` config. The same worker -can handle multiple streams, but unless otherwise documented, each stream can only -have a single writer. +have a [`worker_name`](usage/configuration/config_documentation.md#worker_name) +and be listed in the [`instance_map`](usage/configuration/config_documentation.md#instance_map) +config. The same worker can handle multiple streams, but unless otherwise documented, +each stream can only have a single writer. For example, to move event persistence off to a dedicated worker, the shared configuration would include: @@ -360,9 +366,26 @@ streams and the endpoints associated with them: ##### The `events` stream -The `events` stream experimentally supports having multiple writers, where work -is sharded between them by room ID. Note that you *must* restart all worker -instances when adding or removing event persisters. An example `stream_writers` +The `events` stream experimentally supports having multiple writer workers, where load +is sharded between them by room ID. Each writer is called an _event persister_. They are +responsible for +- receiving new events, +- linking them to those already in the room [DAG](development/room-dag-concepts.md), +- persisting them to the DB, and finally +- updating the events stream. + +Because load is sharded in this way, you *must* restart all worker instances when +adding or removing event persisters. + +An `event_persister` should not be mistaken for an `event_creator`. +An `event_creator` listens for requests from clients to create new events and does +so. It will then pass those events over HTTP replication to any configured event +persisters (or the main process if none are configured). + +Note that `event_creator`s and `event_persister`s are implemented using the same +[`synapse.app.generic_worker`](#synapse.app.generic_worker). + +An example [`stream_writers`](usage/configuration/config_documentation.md#stream_writers) configuration with multiple writers: ```yaml @@ -416,16 +439,18 @@ worker. Background tasks are run periodically or started via replication. Exactl which tasks are configured to run depends on your Synapse configuration (e.g. if stats is enabled). This worker doesn't handle any REST endpoints itself. -To enable this, the worker must have a `worker_name` and can be configured to run -background tasks. For example, to move background tasks to a dedicated worker, -the shared configuration would include: +To enable this, the worker must have a unique +[`worker_name`](usage/configuration/config_documentation.md#worker_name) +and can be configured to run background tasks. For example, to move background tasks +to a dedicated worker, the shared configuration would include: ```yaml run_background_tasks_on: background_worker ``` -You might also wish to investigate the `update_user_directory_from_worker` and -`media_instance_running_background_jobs` settings. +You might also wish to investigate the +[`update_user_directory_from_worker`](#updating-the-user-directory) and +[`media_instance_running_background_jobs`](#synapseappmedia_repository) settings. An example for a dedicated background worker instance: @@ -478,13 +503,17 @@ worker application type. ### `synapse.app.pusher` Handles sending push notifications to sygnal and email. Doesn't handle any -REST endpoints itself, but you should set `start_pushers: False` in the +REST endpoints itself, but you should set +[`start_pushers: false`](usage/configuration/config_documentation.md#start_pushers) in the shared configuration file to stop the main synapse sending push notifications. -To run multiple instances at once the `pusher_instances` option should list all -pusher instances by their worker name, e.g.: +To run multiple instances at once the +[`pusher_instances`](usage/configuration/config_documentation.md#pusher_instances) +option should list all pusher instances by their +[`worker_name`](usage/configuration/config_documentation.md#worker_name), e.g.: ```yaml +start_pushers: false pusher_instances: - pusher_worker1 - pusher_worker2 @@ -512,15 +541,20 @@ Note this worker cannot be load-balanced: only one instance should be active. ### `synapse.app.federation_sender` Handles sending federation traffic to other servers. Doesn't handle any -REST endpoints itself, but you should set `send_federation: False` in the -shared configuration file to stop the main synapse sending this traffic. +REST endpoints itself, but you should set +[`send_federation: false`](usage/configuration/config_documentation.md#send_federation) +in the shared configuration file to stop the main synapse sending this traffic. If running multiple federation senders then you must list each -instance in the `federation_sender_instances` option by their `worker_name`. +instance in the +[`federation_sender_instances`](usage/configuration/config_documentation.md#federation_sender_instances) +option by their +[`worker_name`](usage/configuration/config_documentation.md#worker_name). All instances must be stopped and started when adding or removing instances. For example: ```yaml +send_federation: false federation_sender_instances: - federation_sender1 - federation_sender2 @@ -547,7 +581,9 @@ Handles the media repository. It can handle all endpoints starting with: ^/_synapse/admin/v1/quarantine_media/.*$ ^/_synapse/admin/v1/users/.*/media$ -You should also set `enable_media_repo: False` in the shared configuration +You should also set +[`enable_media_repo: False`](usage/configuration/config_documentation.md#enable_media_repo) +in the shared configuration file to stop the main synapse running background jobs related to managing the media repository. Note that doing so will prevent the main process from being able to handle the above endpoints. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index b62b3b9205..94d1150415 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -53,7 +53,7 @@ DEFAULT_LOG_CONFIG = Template( # Synapse also supports structured logging for machine readable logs which can # be ingested by ELK stacks. See [2] for details. # -# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [1]: https://docs.python.org/3/library/logging.config.html#configuration-dictionary-schema # [2]: https://matrix-org.github.io/synapse/latest/structured_logging.html version: 1 -- cgit 1.4.1 From 67583281e3f8ea923eedbc56a4c85c7ba75d1582 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Oct 2022 09:58:12 -0400 Subject: Fix tests for change in PostgreSQL 14 behavior change. (#14310) PostgreSQL 14 changed the behavior of `websearch_to_tsquery` to improve some behaviour. The tests were hitting those edge-cases about handling of hanging double quotes. This fixes the tests to take into account the PostgreSQL version. --- changelog.d/14310.feature | 1 + synapse/storage/databases/main/search.py | 5 ++--- tests/storage/test_room_search.py | 16 ++++++++++++---- 3 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14310.feature (limited to 'synapse') diff --git a/changelog.d/14310.feature b/changelog.d/14310.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/14310.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index a89fc54c2c..594b935614 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -824,9 +824,8 @@ def _tokenize_query(query: str) -> TokenList: in_phrase = False parts = deque(query.split('"')) for i, part in enumerate(parts): - # The contents inside double quotes is treated as a phrase, a trailing - # double quote is not implied. - in_phrase = bool(i % 2) and i != (len(parts) - 1) + # The contents inside double quotes is treated as a phrase. + in_phrase = bool(i % 2) # Pull out the individual words, discarding any non-word characters. words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 9ddc19900a..868b5bee84 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -239,7 +239,6 @@ class MessageSearchTest(HomeserverTestCase): ("fox -nope", (True, False)), ("fox -brown", (False, True)), ('"fox" quick', True), - ('"fox quick', True), ('"quick brown', True), ('" quick "', True), ('" nope"', False), @@ -269,6 +268,15 @@ class MessageSearchTest(HomeserverTestCase): response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) self.assertIn("event_id", response) + # The behaviour of a missing trailing double quote changed in PostgreSQL 14 + # from ignoring the initial double quote to treating it as a phrase. + main_store = homeserver.get_datastores().main + found = False + if isinstance(main_store.database_engine, PostgresEngine): + assert main_store.database_engine._version is not None + found = main_store.database_engine._version < 140000 + self.COMMON_CASES.append(('"fox quick', (found, True))) + def test_tokenize_query(self) -> None: """Test the custom logic to tokenize a user's query.""" cases = ( @@ -280,9 +288,9 @@ class MessageSearchTest(HomeserverTestCase): ("fox -brown", ["fox", SearchToken.Not, "brown"]), ("- fox", [SearchToken.Not, "fox"]), ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), - # No trailing double quoe. - ('"fox quick', ["fox", SearchToken.And, "quick"]), - ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + # No trailing double quote. + ('"fox quick', [Phrase(["fox", "quick"])]), + ('"-fox quick', [Phrase(["-fox", "quick"])]), ('" quick "', [Phrase(["quick"])]), ( 'q"uick brow"n', -- cgit 1.4.1 From aa70556699e649f46f51a198fb104eecdc0d311b Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 27 Oct 2022 13:29:23 -0500 Subject: Check appservice user interest against the local users instead of all users (`get_users_in_room` mis-use) (#13958) --- changelog.d/13958.bugfix | 1 + docs/upgrade.md | 19 ++++ synapse/appservice/__init__.py | 16 ++- synapse/storage/databases/main/appservice.py | 17 ++- synapse/storage/databases/main/roommember.py | 3 + tests/appservice/test_appservice.py | 10 +- tests/handlers/test_appservice.py | 162 ++++++++++++++++++++++++++- 7 files changed, 214 insertions(+), 14 deletions(-) create mode 100644 changelog.d/13958.bugfix (limited to 'synapse') diff --git a/changelog.d/13958.bugfix b/changelog.d/13958.bugfix new file mode 100644 index 0000000000..f9f651bfdc --- /dev/null +++ b/changelog.d/13958.bugfix @@ -0,0 +1 @@ +Check appservice user interest against the local users instead of all users in the room to align with [MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905). diff --git a/docs/upgrade.md b/docs/upgrade.md index 78c34d0c15..f095bbc3a6 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -97,6 +97,25 @@ As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_s Modules relying on it can instead use the `create_login_token` method. +## Changes to the events received by application services (interest) + +To align with spec (changed in +[MSC3905](https://github.com/matrix-org/matrix-spec-proposals/pull/3905)), Synapse now +only considers local users to be interesting. In other words, the `users` namespace +regex is only be applied against local users of the homeserver. + +Please note, this probably doesn't affect the expected behavior of your application +service, since an interesting local user in a room still means all messages in the room +(from local or remote users) will still be considered interesting. And matching a room +with the `rooms` or `aliases` namespace regex will still consider all events sent in the +room to be interesting to the application service. + +If one of your application service's `users` regex was intending to match a remote user, +this will no longer match as you expect. The behavioral mismatch between matching all +local users and some remote users is why the spec was changed/clarified and this +caveat is no longer supported. + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 0dfa00df44..500bdde3a9 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -172,12 +172,24 @@ class ApplicationService: Returns: True if this service would like to know about this room. """ - member_list = await store.get_users_in_room( + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + # + # In the future, we can consider re-using + # `store.get_app_service_users_in_room` which is very similar to this + # function but has a slightly worse performance than this because we + # have an early escape-hatch if we find a single user that the + # appservice is interested in. The juice would be worth the squeeze if + # `store.get_app_service_users_in_room` was used in more places besides + # an experimental MSC. But for now we can avoid doing more work and + # barely using it later. + local_user_ids = await store.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) # check joined member events - for user_id in member_list: + for user_id in local_user_ids: if self.is_interested_in_user(user_id): return True return False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 64b70a7b28..63046c0527 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -157,10 +157,23 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): app_service: "ApplicationService", cache_context: _CacheContext, ) -> List[str]: - users_in_room = await self.get_users_in_room( + """ + Get all users in a room that the appservice controls. + + Args: + room_id: The room to check in. + app_service: The application service to check interest/control against + + Returns: + List of user IDs that the appservice controls. + """ + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + local_users_in_room = await self.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) - return list(filter(app_service.is_interested_in_user, users_in_room)) + return list(filter(app_service.is_interested_in_user, local_users_in_room)) class ApplicationServiceStore(ApplicationServiceWorkerStore): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ab708b0ba5..e56a13f21e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -152,6 +152,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): the forward extremities of those rooms will exclude most members. We may also calculate room state incorrectly for such rooms and believe that a member is or is not in the room when the opposite is true. + + Note: If you only care about users in the room local to the homeserver, use + `get_local_users_in_room(...)` instead which will be more performant. """ return await self.db_pool.simple_select_onecol( table="current_state_events", diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 3018d3fc6f..d4dccfc2f0 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store = Mock() self.store.get_aliases_for_room = simple_async_mock([]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): @@ -129,7 +129,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -184,7 +184,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room = simple_async_mock( ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertFalse( ( yield defer.ensureDeferred( @@ -203,7 +203,7 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) - self.store.get_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -236,7 +236,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_users_in_room = simple_async_mock( + self.store.get_local_users_in_room = simple_async_mock( ["@alice:here", "@irc_fo:here", "@bob:here"] ) self.store.get_aliases_for_room = simple_async_mock([]) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 7e4570f990..144e49d0fd 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.api.constants import EduTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ( ApplicationService, TransactionOneTimeKeyCounts, @@ -36,7 +36,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import make_awaitable, simple_async_mock +from tests.test_utils import event_injection, make_awaitable, simple_async_mock from tests.unittest import override_config from tests.utils import MockClock @@ -390,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events self.send_mock = simple_async_mock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] return_value=self._services ) @@ -416,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) + def _notify_interested_services(self): + # This is normally set in `notify_interested_services` but we need to call the + # internal async version so the reactor gets pushed to completion. + self.hs.get_application_service_handler().current_max += 1 + self.get_success( + self.hs.get_application_service_handler()._notify_interested_services( + RoomStreamToken( + None, self.hs.get_application_service_handler().current_max + ) + ) + ) + + @parameterized.expand( + [ + ("@local_as_user:test", True), + # Defining remote users in an application service user namespace regex is a + # footgun since the appservice might assume that it'll receive all events + # sent by that remote user, but it will only receive events in rooms that + # are shared with a local user. So we just remove this footgun possibility + # entirely and we won't notify the application service based on remote + # users. + ("@remote_as_user:remote", False), + ] + ) + def test_match_interesting_room_members( + self, interesting_user: str, should_notify: bool + ): + """ + Test to make sure that a interesting user (local or remote) in the room is + notified as expected when someone else in the room sends a message. + """ + # Register an application service that's interested in the `interesting_user` + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": interesting_user, + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # Join the interesting user to the room + self.get_success( + event_injection.inject_member_event( + self.hs, room_id, interesting_user, "join" + ) + ) + # Kick the appservice into checking this membership event to get the event out + # of the way + self._notify_interested_services() + # We don't care about the interesting user join event (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from an uninteresting user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from uninteresting user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + if should_notify: + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Even though the message came from an uninteresting user, it should still + # notify us because the interesting user is joined to the room where the + # message was sent. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + else: + self.send_mock.assert_not_called() + + def test_application_services_receive_events_sent_by_interesting_local_user(self): + """ + Test to make sure that a messages sent from a local user can be interesting and + picked up by the appservice. + """ + # Register an application service that's interested in all local users + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": ".*", + "exclusive": False, + }, + ], + }, + ) + + # Create a room + alice = self.register_user("alice", "pass") + alice_access_token = self.login("alice", "pass") + room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token) + + # We don't care about interesting events before this (this test is making sure + # the next thing works) + self.send_mock.reset_mock() + + # Send a message from the interesting local user + self.helper.send_event( + room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message from interesting local user", + }, + tok=alice_access_token, + ) + # Kick the appservice into checking this new event + self._notify_interested_services() + + self.send_mock.assert_called_once() + ( + service, + events, + _ephemeral, + _to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Events sent from an interesting local user should also be picked up as + # interesting to the appservice. + self.assertEqual(service, interested_appservice) + self.assertEqual(events[0]["type"], "m.room.message") + self.assertEqual(events[0]["sender"], alice) + def test_sending_read_receipt_batches_to_application_services(self): """Tests that a large batch of read receipts are sent correctly to interested application services. -- cgit 1.4.1 From 6a6e1e8c0711939338f25d8d41d1e4d33d984949 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 28 Oct 2022 10:53:34 +0000 Subject: Fix room creation being rate limited too aggressively since Synapse v1.69.0. (#14314) * Introduce a test for the old behaviour which we want to restore * Reintroduce the old behaviour in a simpler way * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) * Use 1 credit instead of 2 for creating a room: be more lenient than before Notably, the UI in Element Web was still broken after restoring to prior behaviour. After discussion, we agreed that it would be sensible to increase the limit. Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/14314.bugfix | 1 + synapse/api/ratelimiting.py | 8 +++++- synapse/handlers/room.py | 16 ++++++++---- tests/rest/client/test_rooms.py | 54 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14314.bugfix (limited to 'synapse') diff --git a/changelog.d/14314.bugfix b/changelog.d/14314.bugfix new file mode 100644 index 0000000000..8be47ee083 --- /dev/null +++ b/changelog.d/14314.bugfix @@ -0,0 +1 @@ +Fix room creation being rate limited too aggressively since Synapse v1.69.0. \ No newline at end of file diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 044c7d4926..511790c7c5 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -343,6 +343,7 @@ class RequestRatelimiter: requester: Requester, update: bool = True, is_admin_redaction: bool = False, + n_actions: int = 1, ) -> None: """Ratelimits requests. @@ -355,6 +356,8 @@ class RequestRatelimiter: is_admin_redaction: Whether this is a room admin/moderator redacting an event. If so then we may apply different ratelimits depending on config. + n_actions: Multiplier for the number of actions to apply to the + rate limiter at once. Raises: LimitExceededError if the request should be ratelimited @@ -383,7 +386,9 @@ class RequestRatelimiter: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + await self.admin_redaction_ratelimiter.ratelimit( + requester, update=update, n_actions=n_actions + ) else: # Override rate and burst count per-user await self.request_ratelimiter.ratelimit( @@ -391,4 +396,5 @@ class RequestRatelimiter: rate_hz=messages_per_second, burst_count=burst_count, update=update, + n_actions=n_actions, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..d74b675adc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -559,7 +559,6 @@ class RoomCreationHandler: invite_list=[], initial_state=initial_state, creation_content=creation_content, - ratelimit=False, ) # Transfer membership events @@ -753,6 +752,10 @@ class RoomCreationHandler: ) if ratelimit: + # Rate limit once in advance, but don't rate limit the individual + # events in the room — room creation isn't atomic and it's very + # janky if half the events in the initial state don't make it because + # of rate limiting. await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( @@ -913,7 +916,6 @@ class RoomCreationHandler: room_alias=room_alias, power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, - ratelimit=ratelimit, ) if "name" in config: @@ -1037,7 +1039,6 @@ class RoomCreationHandler: room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, - ratelimit: bool = True, ) -> Tuple[int, str, int]: """Sends the initial events into a new room. Sends the room creation, membership, and power level events into the room sequentially, then creates and batches up the @@ -1046,6 +1047,8 @@ class RoomCreationHandler: `power_level_content_override` doesn't apply when initial state has power level state event content. + Rate limiting should already have been applied by this point. + Returns: A tuple containing the stream ID, event ID and depth of the last event sent to the room. @@ -1144,7 +1147,7 @@ class RoomCreationHandler: creator.user, room_id, "join", - ratelimit=ratelimit, + ratelimit=False, content=creator_join_profile, new_room=True, prev_event_ids=[last_sent_event_id], @@ -1269,7 +1272,10 @@ class RoomCreationHandler: events_to_send.append((encryption_event, encryption_context)) last_event = await self.event_creation_handler.handle_new_client_event( - creator, events_to_send, ignore_shadow_ban=True + creator, + events_to_send, + ignore_shadow_ban=True, + ratelimit=False, ) assert last_event.internal_metadata.stream_ordering is not None return last_event.internal_metadata.stream_ordering, last_event.event_id, depth diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 716366eb90..1084d4ad9d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -54,6 +54,7 @@ from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable from tests.test_utils.event_injection import create_event +from tests.unittest import override_config PATH_PREFIX = b"/_matrix/client/api/v1" @@ -871,6 +872,41 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + def _create_basic_room(self) -> Tuple[int, object]: + """ + Tries to create a basic room and returns the response code. + """ + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + return channel.code, channel.json_body + + @override_config( + { + "rc_message": {"per_second": 0.2, "burst_count": 10}, + } + ) + def test_room_creation_ratelimiting(self) -> None: + """ + Regression test for #14312, where ratelimiting was made too strict. + Clients should be able to create 10 rooms in a row + without hitting rate limits, using default rate limit config. + (We override rate limiting config back to its default value.) + + To ensure we don't make ratelimiting too generous accidentally, + also check that we can't create an 11th room. + """ + + for _ in range(10): + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.OK, json_body) + + # The 6th room hits the rate limit. + code, json_body = self._create_basic_room() + self.assertEqual(code, HTTPStatus.TOO_MANY_REQUESTS, json_body) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -1390,10 +1426,22 @@ class RoomJoinRatelimitTestCase(RoomBase): ) def test_join_local_ratelimit(self) -> None: """Tests that local joins are actually rate-limited.""" - for _ in range(3): - self.helper.create_room_as(self.user_id) + # Create 4 rooms + room_ids = [ + self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4) + ] + + joiner_user_id = self.register_user("joiner", "secret") + # Now make a new user try to join some of them. - self.helper.create_room_as(self.user_id, expect_code=429) + # The user can join 3 rooms + for room_id in room_ids[0:3]: + self.helper.join(room_id, joiner_user_id) + + # But the user cannot join a 4th room + self.helper.join( + room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS + ) @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} -- cgit 1.4.1 From 81815e0561eea91dbf0c29731589fac2e6f98a40 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Oct 2022 11:44:10 -0400 Subject: Switch search SQL to triple-quote strings. (#14311) For ease of reading we switch from concatenated strings to triple quote strings. --- changelog.d/14311.feature | 1 + synapse/storage/databases/main/search.py | 188 ++++++++++++++++--------------- 2 files changed, 100 insertions(+), 89 deletions(-) create mode 100644 changelog.d/14311.feature (limited to 'synapse') diff --git a/changelog.d/14311.feature b/changelog.d/14311.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/14311.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 594b935614..e9588d1755 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -80,11 +80,11 @@ class SearchWorkerStore(SQLBaseStore): if not self.hs.config.server.enable_search: return if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) + sql = """ + INSERT INTO event_search + (event_id, room_id, key, vector, stream_ordering, origin_server_ts) + VALUES (?,?,?,to_tsvector('english', ?),?,?) + """ args1 = ( ( @@ -101,20 +101,20 @@ class SearchWorkerStore(SQLBaseStore): txn.execute_batch(sql, args1) elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args2 = ( - ( - entry.event_id, - entry.room_id, - entry.key, - _clean_value_for_search(entry.value), - ) - for entry in entries + self.db_pool.simple_insert_many_txn( + txn, + table="event_search", + keys=("event_id", "room_id", "key", "value"), + values=( + ( + entry.event_id, + entry.room_id, + entry.key, + _clean_value_for_search(entry.value), + ) + for entry in entries + ), ) - txn.execute_batch(sql, args2) else: # This should be unreachable. @@ -162,15 +162,17 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): TYPES = ["m.room.name", "m.room.message", "m.room.topic"] def reindex_search_txn(txn: LoggingTransaction) -> int: - sql = ( - "SELECT stream_ordering, event_id, room_id, type, json, " - " origin_server_ts FROM events" - " JOIN event_json USING (room_id, event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " AND (%s)" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) + sql = """ + SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts + FROM events + JOIN event_json USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND (%s) + ORDER BY stream_ordering DESC + LIMIT ? + """ % ( + " OR ".join("type = '%s'" % (t,) for t in TYPES), + ) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) @@ -284,8 +286,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): try: c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx" - " ON event_search USING GIN (vector)" + """ + CREATE INDEX CONCURRENTLY event_search_fts_idx + ON event_search USING GIN (vector) + """ ) except psycopg2.ProgrammingError as e: logger.warning( @@ -323,12 +327,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # We create with NULLS FIRST so that when we search *backwards* # we get the ones with non null origin_server_ts *first* c.execute( - "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" - "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_room_order + ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) c.execute( - "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" - "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_order + ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) conn.set_session(autocommit=False) @@ -345,14 +353,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]: - sql = ( - "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," - " origin_server_ts = e.origin_server_ts" - " FROM events AS e" - " WHERE e.event_id = es.event_id" - " AND ? <= e.stream_ordering AND e.stream_ordering < ?" - " RETURNING es.stream_ordering" - ) + sql = """ + UPDATE event_search AS es + SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts + FROM events AS e + WHERE e.event_id = es.event_id + AND ? <= e.stream_ordering AND e.stream_ordering < ? + RETURNING es.stream_ordering + """ min_stream_id = max_stream_id - batch_size txn.execute(sql, (min_stream_id, max_stream_id)) @@ -456,33 +464,33 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term tsquery_func = self.database_engine.tsquery_func - sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," - " room_id, event_id" - " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" - ) + sql = f""" + SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank, + room_id, event_id + FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" - ) + count_sql = f""" + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_for_sqlite(search_term) - sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" - " FROM event_search" - " WHERE value MATCH ?" - ) + sql = """ + SELECT rank(matchinfo(event_search)) as rank, room_id, event_id + FROM event_search + WHERE value MATCH ? + """ args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ?" - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? + """ count_args = [search_query] + count_args else: # This should be unreachable. @@ -588,26 +596,27 @@ class SearchStore(SearchBackgroundUpdateStore): raise SynapseError(400, "Invalid pagination token") clauses.append( - "(origin_server_ts < ?" - " OR (origin_server_ts = ? AND stream_ordering < ?))" + """ + (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?)) + """ ) args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): search_query = search_term tsquery_func = self.database_engine.tsquery_func - sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," - " origin_server_ts, stream_ordering, room_id, event_id" - " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " - ) + sql = f""" + SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank, + origin_server_ts, stream_ordering, room_id, event_id + FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) AND + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " - ) + count_sql = f""" + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) AND + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -619,23 +628,24 @@ class SearchStore(SearchBackgroundUpdateStore): # in the events table to get the topological ordering. We need # to use the indexes in this order because sqlite refuses to # MATCH unless it uses the full text search index - sql = ( - "SELECT rank(matchinfo) as rank, room_id, event_id," - " origin_server_ts, stream_ordering" - " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" - " FROM event_search" - " WHERE value MATCH ?" - " )" - " CROSS JOIN events USING (event_id)" - " WHERE " + sql = """ + SELECT + rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering + FROM ( + SELECT key, event_id, matchinfo(event_search) as matchinfo + FROM event_search + WHERE value MATCH ? ) + CROSS JOIN events USING (event_id) + WHERE + """ search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ? AND " - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? AND + """ count_args = [search_query] + count_args else: # This should be unreachable. @@ -647,10 +657,10 @@ class SearchStore(SearchBackgroundUpdateStore): # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. if isinstance(self.database_engine, PostgresEngine): - sql += ( - " ORDER BY origin_server_ts DESC NULLS LAST," - " stream_ordering DESC NULLS LAST LIMIT ?" - ) + sql += """ + ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST + LIMIT ? + """ elif isinstance(self.database_engine, Sqlite3Engine): sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" else: -- cgit 1.4.1 From 730b13dbc9e48181b1aaf38be870ec21364b1e9c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 28 Oct 2022 17:04:02 +0100 Subject: Improve `RawHeaders` type hints (#14303) --- changelog.d/14303.misc | 1 + synapse/app/generic_worker.py | 8 ++++---- synapse/http/client.py | 24 +++++++++++++++++++----- 3 files changed, 24 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14303.misc (limited to 'synapse') diff --git a/changelog.d/14303.misc b/changelog.d/14303.misc new file mode 100644 index 0000000000..24ce238223 --- /dev/null +++ b/changelog.d/14303.misc @@ -0,0 +1 @@ +Improve type hinting of `RawHeaders`. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2a9f039367..cb5892f041 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -178,13 +178,13 @@ class KeyUploadServlet(RestServlet): # Proxy headers from the original request, such as the auth headers # (in case the access token is there) and the original IP / # User-Agent of the request. - headers = { - header: request.requestHeaders.getRawHeaders(header, []) + headers: Dict[bytes, List[bytes]] = { + header: list(request.requestHeaders.getRawHeaders(header, [])) for header in (b"Authorization", b"User-Agent") } # Add the previous hop to the X-Forwarded-For header. - x_forwarded_for = request.requestHeaders.getRawHeaders( - b"X-Forwarded-For", [] + x_forwarded_for = list( + request.requestHeaders.getRawHeaders(b"X-Forwarded-For", []) ) # we use request.client here, since we want the previous hop, not the # original client (as returned by request.getClientAddress()). diff --git a/synapse/http/client.py b/synapse/http/client.py index 084d0a5b84..4eb740c040 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -25,7 +25,6 @@ from typing import ( List, Mapping, Optional, - Sequence, Tuple, Union, ) @@ -90,14 +89,29 @@ incoming_responses_counter = Counter( "synapse_http_client_responses", "", ["method", "code"] ) -# the type of the headers list, to be passed to the t.w.h.Headers. -# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so -# we simplify. +# the type of the headers map, to be passed to the t.w.h.Headers. +# +# The actual type accepted by Twisted is +# Mapping[Union[str, bytes], Sequence[Union[str, bytes]] , +# allowing us to mix and match str and bytes freely. However: any str is also a +# Sequence[str]; passing a header string value which is a +# standalone str is interpreted as a sequence of 1-codepoint strings. This is a disastrous footgun. +# We use a narrower value type (RawHeaderValue) to avoid this footgun. +# +# We also simplify the keys to be either all str or all bytes. This helps because +# Dict[K, V] is invariant in K (and indeed V). RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]] # the value actually has to be a List, but List is invariant so we can't specify that # the entries can either be Lists or bytes. -RawHeaderValue = Sequence[Union[str, bytes]] +RawHeaderValue = Union[ + List[str], + List[bytes], + List[Union[str, bytes]], + Tuple[str, ...], + Tuple[bytes, ...], + Tuple[Union[str, bytes], ...], +] def check_against_blacklist( -- cgit 1.4.1 From 7911e2835df7b4bf1dec98b09da89beda65e2ab2 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 28 Oct 2022 18:06:02 +0100 Subject: Prevent federation user keys query from returning device names if disallowed (#14304) --- changelog.d/14304.bugfix | 1 + synapse/handlers/e2e_keys.py | 37 ++++++++++++++++++++--- synapse/storage/databases/main/end_to_end_keys.py | 17 ++++++++--- 3 files changed, 46 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14304.bugfix (limited to 'synapse') diff --git a/changelog.d/14304.bugfix b/changelog.d/14304.bugfix new file mode 100644 index 0000000000..b8d4d91034 --- /dev/null +++ b/changelog.d/14304.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.34.0 where device names would be returned via a federation user key query request when `allow_device_name_lookup_over_federation` was set to `false`. \ No newline at end of file diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 09a2492afc..a9912c467d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -49,6 +49,7 @@ logger = logging.getLogger(__name__) class E2eKeysHandler: def __init__(self, hs: "HomeServer"): + self.config = hs.config self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -431,13 +432,17 @@ class E2eKeysHandler: @trace @cancellable async def query_local_devices( - self, query: Mapping[str, Optional[List[str]]] + self, + query: Mapping[str, Optional[List[str]]], + include_displaynames: bool = True, ) -> Dict[str, Dict[str, dict]]: """Get E2E device keys for local users Args: query: map from user_id to a list of devices to query (None for all devices) + include_displaynames: Whether to include device displaynames in the returned + device details. Returns: A map from user_id -> device_id -> device details @@ -469,7 +474,9 @@ class E2eKeysHandler: # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = await self.store.get_e2e_device_keys_for_cs_api(local_query) + results = await self.store.get_e2e_device_keys_for_cs_api( + local_query, include_displaynames + ) # Build the result structure for user_id, device_keys in results.items(): @@ -482,11 +489,33 @@ class E2eKeysHandler: async def on_federation_query_client_keys( self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: - """Handle a device key query from a federated server""" + """Handle a device key query from a federated server: + + Handles the path: GET /_matrix/federation/v1/users/keys/query + + Args: + query_body: The body of the query request. Should contain a key + "device_keys" that map to a dictionary of user ID's -> list of + device IDs. If the list of device IDs is empty, all devices of + that user will be queried. + + Returns: + A json dictionary containing the following: + - device_keys: A dictionary containing the requested device information. + - master_keys: An optional dictionary of user ID -> master cross-signing + key info. + - self_signing_key: An optional dictionary of user ID -> self-signing + key info. + """ device_keys_query: Dict[str, Optional[List[str]]] = query_body.get( "device_keys", {} ) - res = await self.query_local_devices(device_keys_query) + res = await self.query_local_devices( + device_keys_query, + include_displaynames=( + self.config.federation.allow_device_name_lookup_over_federation + ), + ) ret = {"device_keys": res} # add in the cross-signing keys diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 8a10ae800c..2a4f58ed92 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -139,11 +139,15 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @trace @cancellable async def get_e2e_device_keys_for_cs_api( - self, query_list: List[Tuple[str, Optional[str]]] + self, + query_list: List[Tuple[str, Optional[str]]], + include_displaynames: bool = True, ) -> Dict[str, Dict[str, JsonDict]]: """Fetch a list of device keys, formatted suitably for the C/S API. Args: - query_list(list): List of pairs of user_ids and device_ids. + query_list: List of pairs of user_ids and device_ids. + include_displaynames: Whether to include the displayname of returned devices + (if one exists). Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -166,9 +170,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker continue r["unsigned"] = {} - display_name = device_info.display_name - if display_name is not None: - r["unsigned"]["device_display_name"] = display_name + if include_displaynames: + # Include the device's display name in the "unsigned" dictionary + display_name = device_info.display_name + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + rv[user_id][device_id] = r return rv -- cgit 1.4.1 From 2bb2c32e8ed5642a5bf3ba1e8c49e10cecc88905 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 13:02:07 +0000 Subject: Avoid incrementing bg process utime/stime counters by negative durations (#14323) --- changelog.d/14323.bugfix | 1 + mypy.ini | 4 +- synapse/metrics/background_process_metrics.py | 6 +- tests/metrics/__init__.py | 0 tests/metrics/test_background_process_metrics.py | 19 +++ tests/metrics/test_metrics.py | 206 +++++++++++++++++++++++ tests/test_metrics.py | 200 ---------------------- 7 files changed, 233 insertions(+), 203 deletions(-) create mode 100644 changelog.d/14323.bugfix create mode 100644 tests/metrics/__init__.py create mode 100644 tests/metrics/test_background_process_metrics.py create mode 100644 tests/metrics/test_metrics.py delete mode 100644 tests/test_metrics.py (limited to 'synapse') diff --git a/changelog.d/14323.bugfix b/changelog.d/14323.bugfix new file mode 100644 index 0000000000..da39bc020c --- /dev/null +++ b/changelog.d/14323.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 0.34.0rc2 where logs could include error spam when background processes are measured as taking a negative amount of time. diff --git a/mypy.ini b/mypy.ini index 34b4523e00..8f1141a239 100644 --- a/mypy.ini +++ b/mypy.ini @@ -56,7 +56,6 @@ exclude = (?x) |tests/rest/media/v1/test_media_storage.py |tests/server.py |tests/server_notices/test_resource_limits_server_notices.py - |tests/test_metrics.py |tests/test_state.py |tests/test_terms_auth.py |tests/util/caches/test_cached_call.py @@ -106,6 +105,9 @@ disallow_untyped_defs = False [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True +[mypy-tests.metrics.test_background_process_metrics] +disallow_untyped_defs = True + [mypy-tests.push.test_bulk_push_rule_evaluator] disallow_untyped_defs = True diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 7a1516d3a8..9ea4e23b31 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -174,8 +174,10 @@ class _BackgroundProcess: diff = new_stats - self._reported_stats self._reported_stats = new_stats - _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime) - _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime) + # For unknown reasons, the difference in times can be negative. See comment in + # synapse.http.request_metrics.RequestMetrics.update_metrics. + _background_process_ru_utime.labels(self.desc).inc(max(diff.ru_utime, 0)) + _background_process_ru_stime.labels(self.desc).inc(max(diff.ru_stime, 0)) _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count) _background_process_db_txn_duration.labels(self.desc).inc( diff.db_txn_duration_sec diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/metrics/test_background_process_metrics.py b/tests/metrics/test_background_process_metrics.py new file mode 100644 index 0000000000..f0f6cb2912 --- /dev/null +++ b/tests/metrics/test_background_process_metrics.py @@ -0,0 +1,19 @@ +from unittest import TestCase as StdlibTestCase +from unittest.mock import Mock + +from synapse.logging.context import ContextResourceUsage, LoggingContext +from synapse.metrics.background_process_metrics import _BackgroundProcess + + +class TestBackgroundProcessMetrics(StdlibTestCase): + def test_update_metrics_with_negative_time_diff(self) -> None: + """We should ignore negative reported utime and stime differences""" + usage = ContextResourceUsage() + usage.ru_stime = usage.ru_utime = -1.0 + + mock_logging_context = Mock(spec=LoggingContext) + mock_logging_context.get_resource_usage.return_value = usage + + process = _BackgroundProcess("test process", mock_logging_context) + # Should not raise + process.update_metrics() diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py new file mode 100644 index 0000000000..bddc4228bc --- /dev/null +++ b/tests/metrics/test_metrics.py @@ -0,0 +1,206 @@ +# Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing_extensions import Protocol + +try: + from importlib import metadata +except ImportError: + import importlib_metadata as metadata # type: ignore[no-redef] + +from unittest.mock import patch + +from pkg_resources import parse_version + +from synapse.app._base import _set_prometheus_client_use_created_metrics +from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.deferred_cache import DeferredCache + +from tests import unittest + + +def get_sample_labels_value(sample): + """Extract the labels and values of a sample. + + prometheus_client 0.5 changed the sample type to a named tuple with more + members than the plain tuple had in 0.4 and earlier. This function can + extract the labels and value from the sample for both sample types. + + Args: + sample: The sample to get the labels and value from. + Returns: + A tuple of (labels, value) from the sample. + """ + + # If the sample has a labels and value attribute, use those. + if hasattr(sample, "labels") and hasattr(sample, "value"): + return sample.labels, sample.value + # Otherwise fall back to treating it as a plain 3 tuple. + else: + _, labels, value = sample + return labels, value + + +class TestMauLimit(unittest.TestCase): + def test_basic(self): + class MetricEntry(Protocol): + foo: int + bar: int + + gauge: InFlightGauge[MetricEntry] = InFlightGauge( + "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] + ) + + def handle1(metrics): + metrics.foo += 2 + metrics.bar = max(metrics.bar, 5) + + def handle2(metrics): + metrics.foo += 3 + metrics.bar = max(metrics.bar, 7) + + gauge.register(("key1",), handle1) + + self.assert_dict( + { + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.unregister(("key1",), handle1) + + self.assert_dict( + { + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.register(("key1",), handle1) + gauge.register(("key2",), handle2) + + self.assert_dict( + { + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, + self.get_metrics_from_gauge(gauge), + ) + + gauge.unregister(("key2",), handle2) + gauge.register(("key1",), handle2) + + self.assert_dict( + { + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) + + def get_metrics_from_gauge(self, gauge): + results = {} + + for r in gauge.collect(): + results[r.name] = { + tuple(labels[x] for x in gauge.labels): value + for labels, value in map(get_sample_labels_value, r.samples) + } + + return results + + +class BuildInfoTests(unittest.TestCase): + def test_get_build(self): + """ + The synapse_build_info metric reports the OS version, Python version, + and Synapse version. + """ + items = list( + filter( + lambda x: b"synapse_build_info{" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + ) + self.assertEqual(len(items), 1) + self.assertTrue(b"osversion=" in items[0]) + self.assertTrue(b"pythonversion=" in items[0]) + self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache: DeferredCache[str, str] = DeferredCache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + +class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase): + if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"): + skip = "prometheus-client too old" + + def test_created_metrics_disabled(self) -> None: + """ + Tests that a brittle hack, to disable `_created` metrics, works. + This involves poking at the internals of prometheus-client. + It's not the end of the world if this doesn't work. + + This test gives us a way to notice if prometheus-client changes + their internals. + """ + import prometheus_client.metrics + + PRIVATE_FLAG_NAME = "_use_created" + + # By default, the pesky `_created` metrics are enabled. + # Check this assumption is still valid. + self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME)) + + with patch("prometheus_client.metrics") as mock: + setattr(mock, PRIVATE_FLAG_NAME, True) + _set_prometheus_client_use_created_metrics(False) + self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False)) diff --git a/tests/test_metrics.py b/tests/test_metrics.py deleted file mode 100644 index 1a70eddc9b..0000000000 --- a/tests/test_metrics.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2018 New Vector Ltd -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -try: - from importlib import metadata -except ImportError: - import importlib_metadata as metadata # type: ignore[no-redef] - -from unittest.mock import patch - -from pkg_resources import parse_version - -from synapse.app._base import _set_prometheus_client_use_created_metrics -from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.deferred_cache import DeferredCache - -from tests import unittest - - -def get_sample_labels_value(sample): - """Extract the labels and values of a sample. - - prometheus_client 0.5 changed the sample type to a named tuple with more - members than the plain tuple had in 0.4 and earlier. This function can - extract the labels and value from the sample for both sample types. - - Args: - sample: The sample to get the labels and value from. - Returns: - A tuple of (labels, value) from the sample. - """ - - # If the sample has a labels and value attribute, use those. - if hasattr(sample, "labels") and hasattr(sample, "value"): - return sample.labels, sample.value - # Otherwise fall back to treating it as a plain 3 tuple. - else: - _, labels, value = sample - return labels, value - - -class TestMauLimit(unittest.TestCase): - def test_basic(self): - gauge = InFlightGauge( - "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] - ) - - def handle1(metrics): - metrics.foo += 2 - metrics.bar = max(metrics.bar, 5) - - def handle2(metrics): - metrics.foo += 3 - metrics.bar = max(metrics.bar, 7) - - gauge.register(("key1",), handle1) - - self.assert_dict( - { - "test1_total": {("key1",): 1}, - "test1_foo": {("key1",): 2}, - "test1_bar": {("key1",): 5}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.unregister(("key1",), handle1) - - self.assert_dict( - { - "test1_total": {("key1",): 0}, - "test1_foo": {("key1",): 0}, - "test1_bar": {("key1",): 0}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.register(("key1",), handle1) - gauge.register(("key2",), handle2) - - self.assert_dict( - { - "test1_total": {("key1",): 1, ("key2",): 1}, - "test1_foo": {("key1",): 2, ("key2",): 3}, - "test1_bar": {("key1",): 5, ("key2",): 7}, - }, - self.get_metrics_from_gauge(gauge), - ) - - gauge.unregister(("key2",), handle2) - gauge.register(("key1",), handle2) - - self.assert_dict( - { - "test1_total": {("key1",): 2, ("key2",): 0}, - "test1_foo": {("key1",): 5, ("key2",): 0}, - "test1_bar": {("key1",): 7, ("key2",): 0}, - }, - self.get_metrics_from_gauge(gauge), - ) - - def get_metrics_from_gauge(self, gauge): - results = {} - - for r in gauge.collect(): - results[r.name] = { - tuple(labels[x] for x in gauge.labels): value - for labels, value in map(get_sample_labels_value, r.samples) - } - - return results - - -class BuildInfoTests(unittest.TestCase): - def test_get_build(self): - """ - The synapse_build_info metric reports the OS version, Python version, - and Synapse version. - """ - items = list( - filter( - lambda x: b"synapse_build_info{" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - ) - self.assertEqual(len(items), 1) - self.assertTrue(b"osversion=" in items[0]) - self.assertTrue(b"pythonversion=" in items[0]) - self.assertTrue(b"version=" in items[0]) - - -class CacheMetricsTests(unittest.HomeserverTestCase): - def test_cache_metric(self): - """ - Caches produce metrics reflecting their state when scraped. - """ - CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = DeferredCache(CACHE_NAME, max_entries=777) - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") - - cache.prefill("1", "hi") - - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } - - self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") - - -class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase): - if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"): - skip = "prometheus-client too old" - - def test_created_metrics_disabled(self) -> None: - """ - Tests that a brittle hack, to disable `_created` metrics, works. - This involves poking at the internals of prometheus-client. - It's not the end of the world if this doesn't work. - - This test gives us a way to notice if prometheus-client changes - their internals. - """ - import prometheus_client.metrics - - PRIVATE_FLAG_NAME = "_use_created" - - # By default, the pesky `_created` metrics are enabled. - # Check this assumption is still valid. - self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME)) - - with patch("prometheus_client.metrics") as mock: - setattr(mock, PRIVATE_FLAG_NAME, True) - _set_prometheus_client_use_created_metrics(False) - self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False)) -- cgit 1.4.1 From cc3a52b33df72bb4230367536b924a6d1f510d36 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 31 Oct 2022 18:07:30 +0100 Subject: Support OIDC backchannel logouts (#11414) If configured an OIDC IdP can log a user's session out of Synapse when they log out of the identity provider. The IdP sends a request directly to Synapse (and must be configured with an endpoint) when a user logs out. --- changelog.d/11414.feature | 1 + docs/openid.md | 14 + docs/usage/configuration/config_documentation.md | 9 + synapse/config/oidc.py | 12 + synapse/handlers/oidc.py | 381 ++++++++++++++++++-- synapse/handlers/sso.py | 71 ++++ synapse/rest/synapse/client/oidc/__init__.py | 4 + .../client/oidc/backchannel_logout_resource.py | 35 ++ synapse/storage/databases/main/registration.py | 21 ++ tests/rest/client/test_auth.py | 390 +++++++++++++++++++-- tests/rest/client/utils.py | 55 ++- tests/server.py | 6 + tests/test_utils/oidc.py | 27 +- 13 files changed, 960 insertions(+), 66 deletions(-) create mode 100644 changelog.d/11414.feature create mode 100644 synapse/rest/synapse/client/oidc/backchannel_logout_resource.py (limited to 'synapse') diff --git a/changelog.d/11414.feature b/changelog.d/11414.feature new file mode 100644 index 0000000000..fc035e50a7 --- /dev/null +++ b/changelog.d/11414.feature @@ -0,0 +1 @@ +Support back-channel logouts from OpenID Connect providers. diff --git a/docs/openid.md b/docs/openid.md index 87ebea4c29..37c5eb244d 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -49,6 +49,13 @@ setting in your configuration file. See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as the text below for example configurations for specific providers. +## OIDC Back-Channel Logout + +Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications. + +This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session. +This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` + ## Sample configs Here are a few configs for providers that should work with Synapse. @@ -123,6 +130,9 @@ oidc_providers: [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. +Keycloak supports OIDC Back-Channel Logout, which sends logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak. +This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak. + Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm. 1. Click `Clients` in the sidebar and click `Create` @@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to | Client Protocol | `openid-connect` | | Access Type | `confidential` | | Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | +| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` | +| Backchannel Logout Session Required (optional) | `On` | 5. Click `Save` 6. On the Credentials tab, update the fields: @@ -167,7 +179,9 @@ oidc_providers: config: localpart_template: "{{ user.preferred_username }}" display_name_template: "{{ user.name }}" + backchannel_logout_enabled: true # Optional ``` + ### Auth0 [Auth0][auth0] is a hosted SaaS IdP solution. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 97fb505a5f..44358faf59 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3021,6 +3021,15 @@ Options for each entry include: which is set to the claims returned by the UserInfo Endpoint and/or in the ID Token. +* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications. + Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`. + Defaults to `false`. + +* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the + `sub` claim matches the subject claim received during login. This check can be disabled by setting + this to `true`. Defaults to `false`. + + You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`. It is possible to configure Synapse to only allow logins if certain attributes match particular values in the OIDC userinfo. The requirements can be listed under diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 5418a332da..0bd83f4010 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "userinfo_endpoint": {"type": "string"}, "jwks_uri": {"type": "string"}, "skip_verification": {"type": "boolean"}, + "backchannel_logout_enabled": {"type": "boolean"}, + "backchannel_logout_ignore_sub": {"type": "boolean"}, "user_profile_method": { "type": "string", "enum": ["auto", "userinfo_endpoint"], @@ -292,6 +294,10 @@ def _parse_oidc_config_dict( token_endpoint=oidc_config.get("token_endpoint"), userinfo_endpoint=oidc_config.get("userinfo_endpoint"), jwks_uri=oidc_config.get("jwks_uri"), + backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False), + backchannel_logout_ignore_sub=oidc_config.get( + "backchannel_logout_ignore_sub", False + ), skip_verification=oidc_config.get("skip_verification", False), user_profile_method=oidc_config.get("user_profile_method", "auto"), allow_existing_users=oidc_config.get("allow_existing_users", False), @@ -368,6 +374,12 @@ class OidcProviderConfig: # "openid" scope is used. jwks_uri: Optional[str] + # Whether Synapse should react to backchannel logouts + backchannel_logout_enabled: bool + + # Whether Synapse should ignore the `sub` claim in backchannel logouts or not. + backchannel_logout_ignore_sub: bool + # Whether to skip metadata verification skip_verification: bool diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 9759daf043..867973dcca 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -12,14 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import binascii import inspect +import json import logging -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Type, + TypeVar, + Union, +) from urllib.parse import urlencode, urlparse import attr +import unpaddedbase64 from authlib.common.security import generate_token -from authlib.jose import JsonWebToken, jwt +from authlib.jose import JsonWebToken, JWTClaims +from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oidc.core import CodeIDToken, UserInfo @@ -35,9 +49,12 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from twisted.web.http_headers import Headers +from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes +from synapse.http.server import finish_request +from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart @@ -88,6 +105,8 @@ class Token(TypedDict): #: there is no real point of doing this in our case. JWK = Dict[str, str] +C = TypeVar("C") + #: A JWK Set, as per RFC7517 sec 5. class JWKS(TypedDict): @@ -247,6 +266,80 @@ class OidcHandler: await oidc_provider.handle_oidc_callback(request, session_data, code) + async def handle_backchannel_logout(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + This extracts the logout_token from the request and tries to figure out + which OpenID Provider it is comming from. This works by matching the iss claim + with the issuer and the aud claim with the client_id. + + Since at this point we don't know who signed the JWT, we can't just + decode it using authlib since it will always verifies the signature. We + have to decode it manually without validating the signature. The actual JWT + verification is done in the `OidcProvider.handler_backchannel_logout` method, + once we figured out which provider sent the request. + + Args: + request: the incoming request from the browser. + """ + logout_token = parse_string(request, "logout_token") + if logout_token is None: + raise SynapseError(400, "Missing logout_token in request") + + # A JWT looks like this: + # header.payload.signature + # where all parts are encoded with urlsafe base64. + # The aud and iss claims we care about are in the payload part, which + # is a JSON object. + try: + # By destructuring the list after splitting, we ensure that we have + # exactly 3 segments + _, payload, _ = logout_token.split(".") + except ValueError: + raise SynapseError(400, "Invalid logout_token in request") + + try: + payload_bytes = unpaddedbase64.decode_base64(payload) + claims = json_decoder.decode(payload_bytes.decode("utf-8")) + except (json.JSONDecodeError, binascii.Error, UnicodeError): + raise SynapseError(400, "Invalid logout_token payload in request") + + try: + # Let's extract the iss and aud claims + iss = claims["iss"] + aud = claims["aud"] + # The aud claim can be either a string or a list of string. Here we + # normalize it as a list of strings. + if isinstance(aud, str): + aud = [aud] + + # Check that we have the right types for the aud and the iss claims + if not isinstance(iss, str) or not isinstance(aud, list): + raise TypeError() + for a in aud: + if not isinstance(a, str): + raise TypeError() + + # At this point we properly checked both claims types + issuer: str = iss + audience: List[str] = aud + except (TypeError, KeyError): + raise SynapseError(400, "Invalid issuer/audience in logout_token") + + # Now that we know the audience and the issuer, we can figure out from + # what provider it is coming from + oidc_provider: Optional[OidcProvider] = None + for provider in self._providers.values(): + if provider.issuer == issuer and provider.client_id in audience: + oidc_provider = provider + break + + if oidc_provider is None: + raise SynapseError(400, "Could not find the OP that issued this event") + + # Ask the provider to handle the logout request. + await oidc_provider.handle_backchannel_logout(request, logout_token) + class OidcError(Exception): """Used to catch errors when calling the token_endpoint""" @@ -342,6 +435,7 @@ class OidcProvider: self.idp_brand = provider.idp_brand self._sso_handler = hs.get_sso_handler() + self._device_handler = hs.get_device_handler() self._sso_handler.register_identity_provider(self) @@ -400,6 +494,41 @@ class OidcProvider: # If we're not using userinfo, we need a valid jwks to validate the ID token m.validate_jwks_uri() + if self._config.backchannel_logout_enabled: + if not m.get("backchannel_logout_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled for issuer %r" + "but it does not advertise support for it", + self.issuer, + ) + + elif not m.get("backchannel_logout_session_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled and supported " + "by issuer %r but it might not send a session ID with " + "logout tokens, which is required for the logouts to work", + self.issuer, + ) + + if not self._config.backchannel_logout_ignore_sub: + # If OIDC backchannel logouts are enabled, the provider mapping provider + # should use the `sub` claim. We verify that by mapping a dumb user and + # see if we get back the sub claim + user = UserInfo({"sub": "thisisasubject"}) + try: + subject = self._user_mapping_provider.get_remote_user_id(user) + if subject != user["sub"]: + raise ValueError("Unexpected subject") + except Exception: + logger.warning( + f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} " + "but it looks like the configured `user_mapping_provider` " + "does not use the `sub` claim as subject. If it is the case, " + "and you want Synapse to ignore the `sub` claim in OIDC " + "Back-Channel Logouts, set `backchannel_logout_ignore_sub` " + "to `true` in the issuer config." + ) + @property def _uses_userinfo(self) -> bool: """Returns True if the ``userinfo_endpoint`` should be used. @@ -415,6 +544,16 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) + @property + def issuer(self) -> str: + """The issuer identifying this provider.""" + return self._config.issuer + + @property + def client_id(self) -> str: + """The client_id used when interacting with this provider.""" + return self._config.client_id + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: """Return the provider metadata. @@ -662,6 +801,59 @@ class OidcProvider: return UserInfo(resp) + async def _verify_jwt( + self, + alg_values: List[str], + token: str, + claims_cls: Type[C], + claims_options: Optional[dict] = None, + claims_params: Optional[dict] = None, + ) -> C: + """Decode and validate a JWT, re-fetching the JWKS as needed. + + Args: + alg_values: list of `alg` values allowed when verifying the JWT. + token: the JWT. + claims_cls: the JWTClaims class to use to validate the claims. + claims_options: dict of options passed to the `claims_cls` constructor. + claims_params: dict of params passed to the `claims_cls` constructor. + + Returns: + The decoded claims in the JWT. + """ + jwt = JsonWebToken(alg_values) + + logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token) + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + except ValueError: + logger.info("Reloading JWKS after decode error") + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + + logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims) + + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew + return claims + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. @@ -675,13 +867,13 @@ class OidcProvider: The decoded claims in the ID token. """ id_token = token.get("id_token") - logger.debug("Attempting to decode JWT id_token %r", id_token) # That has been theoritically been checked by the caller, so even though # assertion are not enabled in production, it is mainly here to appease mypy assert id_token is not None metadata = await self.load_metadata() + claims_params = { "nonce": nonce, "client_id": self._client_auth.client_id, @@ -691,38 +883,17 @@ class OidcProvider: # in the `id_token` that we can check against. claims_params["access_token"] = token["access_token"] - alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) - - claim_options = {"iss": {"values": [metadata["issuer"]]}} + claims_options = {"iss": {"values": [metadata["issuer"]]}} - # Try to decode the keys in cache first, then retry by forcing the keys - # to be reloaded - jwk_set = await self.load_jwks() - try: - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - except ValueError: - logger.info("Reloading JWKS after decode error") - jwk_set = await self.load_jwks(force=True) # try reloading the jwks - claims = jwt.decode( - id_token, - key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, - claims_params=claims_params, - ) - - logger.debug("Decoded id_token JWT %r; validating", claims) + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - claims.validate( - now=self._clock.time(), leeway=120 - ) # allows 2 min of clock skew + claims = await self._verify_jwt( + alg_values=alg_values, + token=id_token, + claims_cls=CodeIDToken, + claims_options=claims_options, + claims_params=claims_params, + ) return claims @@ -1043,6 +1214,146 @@ class OidcProvider: # to be strings. return str(remote_user_id) + async def handle_backchannel_logout( + self, request: SynapseRequest, logout_token: str + ) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + The OIDC Provider posts a logout token to this endpoint when a user + session ends. That token is a JWT signed with the same keys as + ID tokens. The OpenID Connect Back-Channel Logout draft explains how to + validate the JWT and figure out what session to end. + + Args: + request: The request to respond to + logout_token: The logout token (a JWT) extracted from the request body + """ + # Back-Channel Logout can be disabled in the config, hence this check. + # This is not that important for now since Synapse is registered + # manually to the OP, so not specifying the backchannel-logout URI is + # as effective than disabling it here. It might make more sense if we + # support dynamic registration in Synapse at some point. + if not self._config.backchannel_logout_enabled: + logger.warning( + f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config" + ) + + # TODO: this responds with a 400 status code, which is what the OIDC + # Back-Channel Logout spec expects, but spec also suggests answering with + # a JSON object, with the `error` and `error_description` fields set, which + # we are not doing here. + # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse + raise SynapseError( + 400, "OpenID Connect Back-Channel Logout is disabled for this provider" + ) + + metadata = await self.load_metadata() + + # As per OIDC Back-Channel Logout 1.0 sec. 2.4: + # A Logout Token MUST be signed and MAY also be encrypted. The same + # keys are used to sign and encrypt Logout Tokens as are used for ID + # Tokens. If the Logout Token is encrypted, it SHOULD replicate the + # iss (issuer) claim in the JWT Header Parameters, as specified in + # Section 5.3 of [JWT]. + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + # As per sec. 2.6: + # 3. Validate the iss, aud, and iat Claims in the same way they are + # validated in ID Tokens. + # Which means the audience should contain Synapse's client_id and the + # issuer should be the IdP issuer + claims_options = { + "iss": {"values": [metadata["issuer"]]}, + "aud": {"values": [self.client_id]}, + } + + try: + claims = await self._verify_jwt( + alg_values=alg_values, + token=logout_token, + claims_cls=LogoutToken, + claims_options=claims_options, + ) + except JoseError: + logger.exception("Invalid logout_token") + raise SynapseError(400, "Invalid logout_token") + + # As per sec. 2.6: + # 4. Verify that the Logout Token contains a sub Claim, a sid Claim, + # or both. + # 5. Verify that the Logout Token contains an events Claim whose + # value is JSON object containing the member name + # http://schemas.openid.net/event/backchannel-logout. + # 6. Verify that the Logout Token does not contain a nonce Claim. + # This is all verified by the LogoutToken claims class, so at this + # point the `sid` claim exists and is a string. + sid: str = claims.get("sid") + + # If the `sub` claim was included in the logout token, we check that it matches + # that it matches the right user. We can have cases where the `sub` claim is not + # the ID saved in database, so we let admins disable this check in config. + sub: Optional[str] = claims.get("sub") + expected_user_id: Optional[str] = None + if sub is not None and not self._config.backchannel_logout_ignore_sub: + expected_user_id = await self._store.get_user_by_external_id( + self.idp_id, sub + ) + + # Invalidate any running user-mapping sessions, in-flight login tokens and + # active devices + await self._sso_handler.revoke_sessions_for_provider_session_id( + auth_provider_id=self.idp_id, + auth_provider_session_id=sid, + expected_user_id=expected_user_id, + ) + + request.setResponseCode(200) + request.setHeader(b"Cache-Control", b"no-cache, no-store") + request.setHeader(b"Pragma", b"no-cache") + finish_request(request) + + +class LogoutToken(JWTClaims): + """ + Holds and verify claims of a logout token, as per + https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken + """ + + REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"] + + def validate(self, now: Optional[int] = None, leeway: int = 0) -> None: + """Validate everything in claims payload.""" + super().validate(now, leeway) + self.validate_sid() + self.validate_events() + self.validate_nonce() + + def validate_sid(self) -> None: + """Ensure the sid claim is present""" + sid = self.get("sid") + if not sid: + raise MissingClaimError("sid") + + if not isinstance(sid, str): + raise InvalidClaimError("sid") + + def validate_nonce(self) -> None: + """Ensure the nonce claim is absent""" + if "nonce" in self: + raise InvalidClaimError("nonce") + + def validate_events(self) -> None: + """Ensure the events claim is present and with the right value""" + events = self.get("events") + if not events: + raise MissingClaimError("events") + + if not isinstance(events, dict): + raise InvalidClaimError("events") + + if "http://schemas.openid.net/event/backchannel-logout" not in events: + raise InvalidClaimError("events") + # number of seconds a newly-generated client secret should be valid for CLIENT_SECRET_VALIDITY_SECONDS = 3600 @@ -1112,6 +1423,7 @@ class JwtClientSecret: logger.info( "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload ) + jwt = JsonWebToken(header["alg"]) self._cached_secret = jwt.encode(header, payload, self._key.key) self._cached_secret_replacement_time = ( expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS @@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict): emails: List[str] -C = TypeVar("C") - - class OidcMappingProvider(Generic[C]): """A mapping provider maps a UserInfo object to user attributes. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 5943f08e91..749d7e93b0 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -191,6 +191,7 @@ class SsoHandler: self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() self._error_template = hs.config.sso.sso_error_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() @@ -1026,6 +1027,76 @@ class SsoHandler: return True + async def revoke_sessions_for_provider_session_id( + self, + auth_provider_id: str, + auth_provider_session_id: str, + expected_user_id: Optional[str] = None, + ) -> None: + """Revoke any devices and in-flight logins tied to a provider session. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + auth_provider_session_id: The session ID from the provider to logout + expected_user_id: The user we're expecting to logout. If set, it will ignore + sessions belonging to other users and log an error. + """ + # Invalidate any running user-mapping sessions + to_delete = [] + for session_id, session in self._username_mapping_sessions.items(): + if ( + session.auth_provider_id == auth_provider_id + and session.auth_provider_session_id == auth_provider_session_id + ): + to_delete.append(session_id) + + for session_id in to_delete: + logger.info("Revoking mapping session %s", session_id) + del self._username_mapping_sessions[session_id] + + # Invalidate any in-flight login tokens + await self._store.invalidate_login_tokens_by_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # Fetch any device(s) in the store associated with the session ID. + devices = await self._store.get_devices_by_auth_provider_session_id( + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + # We have no guarantee that all the devices of that session are for the same + # `user_id`. Hence, we have to iterate over the list of devices and log them out + # one by one. + for device in devices: + user_id = device["user_id"] + device_id = device["device_id"] + + # If the user_id associated with that device/session is not the one we got + # out of the `sub` claim, skip that device and show log an error. + if expected_user_id is not None and user_id != expected_user_id: + logger.error( + "Received a logout notification from SSO provider " + f"{auth_provider_id!r} for the user {expected_user_id!r}, but with " + f"a session ID ({auth_provider_session_id!r}) which belongs to " + f"{user_id!r}. This may happen when the SSO provider user mapper " + "uses something else than the standard attribute as mapping ID. " + "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` " + "in the provider config if that is the case." + ) + continue + + logger.info( + "Logging out %r (device %r) via SSO (%r) logout notification (session %r).", + user_id, + device_id, + auth_provider_id, + auth_provider_session_id, + ) + await self._device_handler.delete_devices(user_id, [device_id]) + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py index 81fec39659..e4b28ce3df 100644 --- a/synapse/rest/synapse/client/oidc/__init__.py +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING from twisted.web.resource import Resource +from synapse.rest.synapse.client.oidc.backchannel_logout_resource import ( + OIDCBackchannelLogoutResource, +) from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource if TYPE_CHECKING: @@ -29,6 +32,7 @@ class OIDCResource(Resource): def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"callback", OIDCCallbackResource(hs)) + self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs)) __all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py new file mode 100644 index 0000000000..e07e76855a --- /dev/null +++ b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py @@ -0,0 +1,35 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import DirectServeJsonResource +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class OIDCBackchannelLogoutResource(DirectServeJsonResource): + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_POST(self, request: SynapseRequest) -> None: + await self._oidc_handler.handle_backchannel_logout(request) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0255295317..5167089e03 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): self._clock.time_msec(), ) + async def invalidate_login_tokens_by_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> None: + """Invalidate login tokens with the given IdP session ID. + + Args: + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider + """ + await self.db_pool.simple_update( + table="login_tokens", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + updatevalues={"used_ts": self._clock.time_msec()}, + desc="invalidate_login_tokens_by_session_id", + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index ebf653d018..847294dc8e 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple, Union @@ -21,7 +22,7 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType -from synapse.api.errors import Codes +from synapse.api.errors import Codes, SynapseError from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -32,8 +33,8 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.utils import TEST_OIDC_CONFIG -from tests.server import FakeChannel +from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER +from tests.server import FakeChannel, make_request from tests.unittest import override_config, skip_unless @@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token: str) -> bool: - """ - Checks whether an access token is valid, returning whether it is or not. - """ - code = self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code - - # Either 200 or 401 is what we get back; anything else is a bug. - assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} - - return code == HTTPStatus.OK - def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. @@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.reactor.advance(59.0) # Both tokens should still be valid. - self.assertTrue(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 61 s (just past 1 minute, the time of expiry) self.reactor.advance(2.0) # Only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 599 s (just shy of 10 minutes, the time of expiry) self.reactor.advance(599.0 - 61.0) # It's still the case that only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 601 s (just past 10 minutes, the time of expiry) self.reactor.advance(2.0) # Now neither token is valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami( + nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} @@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # and no refresh token self.assertEqual(_table_length("access_tokens"), 0) self.assertEqual(_table_length("refresh_tokens"), 0) + + +def oidc_config( + id: str, with_localpart_template: bool, **kwargs: Any +) -> Dict[str, Any]: + """Sample OIDC provider config used in backchannel logout tests. + + Args: + id: IDP ID for this provider + with_localpart_template: Set to `true` to have a default localpart_template in + the `user_mapping_provider` config and skip the user mapping session + **kwargs: rest of the config + + Returns: + A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of + the HS config + """ + config: Dict[str, Any] = { + "idp_id": id, + "idp_name": id, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + } + + if with_localpart_template: + config["user_mapping_provider"] = { + "config": {"localpart_template": "{{ user.sub }}"} + } + else: + config["user_mapping_provider"] = {"config": {}} + + config.update(kwargs) + + return config + + +@skip_unless(HAS_OIDC, "Requires OIDC") +class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def submit_logout_token(self, logout_token: str) -> FakeChannel: + return self.make_request( + "POST", + "/_synapse/client/oidc/backchannel_logout", + content=f"logout_token={logout_token}", + content_is_form=True, + ) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_simple_logout(self) -> None: + """ + Receiving a logout token should logout the user + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + self.assertNotEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = fake_oidc_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = fake_oidc_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_login(self) -> None: + """ + It should revoke login tokens when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # expect a confirmation page + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now try to exchange the login token + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + # It should have failed + self.assertEqual(channel.code, 403) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=False, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_mapping(self) -> None: + """ + It should stop ongoing user mapping session when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # Expect a user mapping page + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + + # We should have a user_mapping_session cookie + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + user_mapping_session_id = cookies["username_mapping_session"] + + # Getting that session should not raise + session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + self.assertIsNotNone(session) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now it should raise + with self.assertRaises(SynapseError): + self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=False, + ) + ] + } + ) + def test_disabled(self) -> None: + """ + Receiving a logout token should do nothing if it is disabled in the config + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_no_sid(self) -> None: + """ + Receiving a logout token without `sid` during the login should do nothing + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=False + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + "first", + issuer="https://first-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + oidc_config( + "second", + issuer="https://second-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + ] + } + ) + def test_multiple_providers(self) -> None: + """ + It should be able to distinguish login tokens from two different IdPs + """ + first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/") + second_server = self.helper.fake_oidc_server( + issuer="https://second-issuer.com/" + ) + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + first_server, user, with_sid=True, idp_id="oidc-first" + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + second_server, user, with_sid=True, idp_id="oidc-second" + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # `sid` in the fake providers are generated by a counter, so the first grant of + # each provider should give the same SID + self.assertEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = first_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = second_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 967d229223..706399fae5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -553,6 +553,34 @@ class RestHelper: return channel.json_body + def whoami( + self, + access_token: str, + expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK, + ) -> JsonDict: + """Perform a 'whoami' request, which can be a quick way to check for access + token validity + + Args: + access_token: The user token to use during the request + expect_code: The return code to expect from attempting the whoami request + """ + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "account/whoami", + access_token=access_token, + ) + + assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: """Create a ``FakeOidcServer``. @@ -572,6 +600,7 @@ class RestHelper: fake_server: FakeOidcServer, remote_user_id: str, with_sid: bool = False, + idp_id: Optional[str] = None, expected_status: int = 200, ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC @@ -588,7 +617,11 @@ class RestHelper: client_redirect_url = "https://x" userinfo = {"sub": remote_user_id} channel, grant = self.auth_via_oidc( - fake_server, userinfo, client_redirect_url, with_sid=with_sid + fake_server, + userinfo, + client_redirect_url, + with_sid=with_sid, + idp_id=idp_id, ) # expect a confirmation page @@ -623,6 +656,7 @@ class RestHelper: client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, with_sid: bool = False, + idp_id: Optional[str] = None, ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. @@ -648,6 +682,7 @@ class RestHelper: ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. with_sid: if True, generates a random `sid` (OIDC session ID) + idp_id: if set, explicitely chooses one specific IDP Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -665,7 +700,9 @@ class RestHelper: oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) else: # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + oauth_uri = self.initiate_sso_login( + client_redirect_url, cookies, idp_id=idp_id + ) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -742,7 +779,10 @@ class RestHelper: return channel, grant def initiate_sso_login( - self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + self, + client_redirect_url: Optional[str], + cookies: MutableMapping[str, str], + idp_id: Optional[str] = None, ) -> str: """Make a request to the login-via-sso redirect endpoint, and return the target @@ -753,6 +793,7 @@ class RestHelper: client_redirect_url: the client redirect URL to pass to the login redirect endpoint cookies: any cookies returned will be added to this dict + idp_id: if set, explicitely chooses one specific IDP Returns: the URI that the client gets redirected to (ie, the SSO server) @@ -761,6 +802,12 @@ class RestHelper: if client_redirect_url: params["redirectUrl"] = client_redirect_url + uri = "/_matrix/client/r0/login/sso/redirect" + if idp_id is not None: + uri = f"{uri}/{idp_id}" + + uri = f"{uri}?{urllib.parse.urlencode(params)}" + # hit the redirect url (which should redirect back to the redirect url. This # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. @@ -768,7 +815,7 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", - "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), + uri, ) assert channel.code == 302 diff --git a/tests/server.py b/tests/server.py index 8b1d186219..b1730fcc8d 100644 --- a/tests/server.py +++ b/tests/server.py @@ -362,6 +362,12 @@ def make_request( # Twisted expects to be at the end of the content when parsing the request. req.content.seek(0, SEEK_END) + # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded + # bodies if the Content-Length header is missing + req.requestHeaders.addRawHeader( + b"Content-Length", str(len(content)).encode("ascii") + ) + if access_token: req.requestHeaders.addRawHeader( b"Authorization", b"Bearer " + access_token.encode("ascii") diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index de134bbc89..1461d23ee8 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -51,6 +51,8 @@ class FakeOidcServer: get_userinfo_handler: Mock post_token_handler: Mock + sid_counter: int = 0 + def __init__(self, clock: Clock, issuer: str): from authlib.jose import ECKey, KeySet @@ -146,7 +148,7 @@ class FakeOidcServer: return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: - now = self._clock.time() + now = int(self._clock.time()) id_token = { **grant.userinfo, "iss": self.issuer, @@ -166,6 +168,26 @@ class FakeOidcServer: return self._sign(id_token) + def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str: + now = int(self._clock.time()) + logout_token = { + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "jti": random_string(10), + "events": { + "http://schemas.openid.net/event/backchannel-logout": {}, + }, + } + + if grant.sid is not None: + logout_token["sid"] = grant.sid + + if "sub" in grant.userinfo: + logout_token["sub"] = grant.userinfo["sub"] + + return self._sign(logout_token) + def id_token_override(self, overrides: dict): """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -183,7 +205,8 @@ class FakeOidcServer: code = random_string(10) sid = None if with_sid: - sid = random_string(10) + sid = str(self.sid_counter) + self.sid_counter += 1 grant = FakeAuthorizationGrant( userinfo=userinfo, -- cgit 1.4.1 From dbfc9b803ee32f7b31c2b5ccbc53a1bfcaa95983 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 20:31:43 +0000 Subject: Fix dehydrated device REST checks (#14336) --- changelog.d/14336.bugfix | 1 + synapse/rest/client/devices.py | 5 ++--- tests/rest/client/test_devices.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14336.bugfix (limited to 'synapse') diff --git a/changelog.d/14336.bugfix b/changelog.d/14336.bugfix new file mode 100644 index 0000000000..d44ff1bbc7 --- /dev/null +++ b/changelog.d/14336.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70 where clients were unable to PUT new [dehydrated devices](https://github.com/matrix-org/matrix-spec-proposals/pull/2697). diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 90828c95c4..8f3cbd4ea2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -231,7 +231,7 @@ class DehydratedDeviceServlet(RestServlet): } } - PUT /org.matrix.msc2697/dehydrated_device + PUT /org.matrix.msc2697.v2/dehydrated_device Content-Type: application/json { @@ -271,7 +271,6 @@ class DehydratedDeviceServlet(RestServlet): raise errors.NotFoundError("No dehydrated device available") class PutBody(RequestBodyModel): - device_id: StrictStr device_data: DehydratedDeviceDataModel initial_device_display_name: Optional[StrictStr] @@ -281,7 +280,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission.device_data, + submission.device_data.dict(), submission.initial_device_display_name, ) return 200, {"device_id": device_id} diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index aa98222434..d80eea17d3 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase): self.reactor.advance(43200) self.get_success(self.handler.get_device(user_id, "abc")) self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) + + +class DehydratedDeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + + def test_PUT(self) -> None: + """Sanity-check that we can PUT a dehydrated device. + + Detects https://github.com/matrix-org/synapse/issues/14334. + """ + alice = self.register_user("alice", "correcthorse") + token = self.login(alice, "correcthorse") + + # Have alice update their device list + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device", + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device", + } + }, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + device_id = channel.json_body.get("device_id") + self.assertIsInstance(device_id, str) -- cgit 1.4.1 From b922b54b6143f13c0786a18fcbb5f55724ea72fc Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 1 Nov 2022 10:30:43 +0000 Subject: Fix type annotation causing import time error in the Complement forking launcher. (#14084) Co-authored-by: David Robertson --- changelog.d/14084.misc | 1 + synapse/app/complement_fork_starter.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14084.misc (limited to 'synapse') diff --git a/changelog.d/14084.misc b/changelog.d/14084.misc new file mode 100644 index 0000000000..988e55f437 --- /dev/null +++ b/changelog.d/14084.misc @@ -0,0 +1 @@ +Fix type annotation causing import time error in the Complement forking launcher. \ No newline at end of file diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index b22f315453..8c0f4a57e7 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -55,13 +55,13 @@ import os import signal import sys from types import FrameType -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional from twisted.internet.main import installReactor # a list of the original signal handlers, before we installed our custom ones. # We restore these in our child processes. -_original_signal_handlers: dict[int, Any] = {} +_original_signal_handlers: Dict[int, Any] = {} class ProxiedReactor: -- cgit 1.4.1 From 9473ebb9e7db9e3f71b341f72ae004db3a0144b8 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Nov 2022 11:47:09 +0000 Subject: Revert "Fix event size checks (#13710)" This reverts commit fab495a9e1442d99e922367f65f41de5eaa488eb. As noted in https://github.com/matrix-org/synapse/pull/13710#issuecomment-1298396007: > We want to see this change land for the protocol's sake (and plan to un-revert it) but want to give this a little more time before releasing this. --- changelog.d/13710.bugfix | 1 - synapse/event_auth.py | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) delete mode 100644 changelog.d/13710.bugfix (limited to 'synapse') diff --git a/changelog.d/13710.bugfix b/changelog.d/13710.bugfix deleted file mode 100644 index 4c318d15f5..0000000000 --- a/changelog.d/13710.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where Synapse would count codepoints instead of bytes when validating the size of some fields. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5036604036..bab31e33c5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -342,15 +342,15 @@ def check_state_dependent_auth_rules( def _check_size_limits(event: "EventBase") -> None: - if len(event.user_id.encode("utf-8")) > 255: + if len(event.user_id) > 255: raise EventSizeError("'user_id' too large") - if len(event.room_id.encode("utf-8")) > 255: + if len(event.room_id) > 255: raise EventSizeError("'room_id' too large") - if event.is_state() and len(event.state_key.encode("utf-8")) > 255: + if event.is_state() and len(event.state_key) > 255: raise EventSizeError("'state_key' too large") - if len(event.type.encode("utf-8")) > 255: + if len(event.type) > 255: raise EventSizeError("'type' too large") - if len(event.event_id.encode("utf-8")) > 255: + if len(event.event_id) > 255: raise EventSizeError("'event_id' too large") if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: raise EventSizeError("event too large") -- cgit 1.4.1 From 2bd7f3eeab1a4818359c9f585b660ff3f3d8bc6c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Nov 2022 15:02:39 +0000 Subject: Allow PUT/GET of aliases during faster join (#14292) without blocking on full state. --- changelog.d/14292.bugfix | 1 + synapse/handlers/directory.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14292.bugfix (limited to 'synapse') diff --git a/changelog.d/14292.bugfix b/changelog.d/14292.bugfix new file mode 100644 index 0000000000..4ed92f5cf2 --- /dev/null +++ b/changelog.d/14292.bugfix @@ -0,0 +1 @@ +Faster joins: do not block creation of or queries for room aliases during the resync. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index d52ebada6b..2ea52257cb 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -85,7 +85,7 @@ class DirectoryHandler: # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - servers = await self._storage_controllers.state.get_current_hosts_in_room( + servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) @@ -290,7 +290,7 @@ class DirectoryHandler: Codes.NOT_FOUND, ) - extra_servers = await self._storage_controllers.state.get_current_hosts_in_room( + extra_servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) servers_set = set(extra_servers) | set(servers) -- cgit 1.4.1 From d4fac8a3e27ab3e133c5e5ac603c8d937a1fd86c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 1 Nov 2022 19:20:35 +0000 Subject: Fix typo in #13320 which could cause log spam (#14347) --- changelog.d/14347.bugfix | 1 + synapse/federation/federation_client.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14347.bugfix (limited to 'synapse') diff --git a/changelog.d/14347.bugfix b/changelog.d/14347.bugfix new file mode 100644 index 0000000000..91975757ae --- /dev/null +++ b/changelog.d/14347.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.64.0rc1 which could cause log spam when fetching events from other homeservers. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index fa225182be..c4c0bc7315 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -465,7 +465,7 @@ class FederationClient(FederationBase): pdu_attempts[destination] = now logger.info( - "get_pdu(event_id=): Failed to get PDU from %s because %s", + "get_pdu(event_id=%s): Failed to get PDU from %s because %s", event_id, destination, e, -- cgit 1.4.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'synapse') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event IDs to look and redact relations of. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for. + + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to. + event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the tread by the mod user is not + # redacted. + event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. + """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.4.1