From 4c2096599c9780290703e14df63963e77d058dda Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 21 Jan 2022 08:38:36 +0000 Subject: Make the `get_global_account_data_by_type_for_user` cache be a tree-cache whose key is prefixed with the user ID (#11788) --- synapse/rest/client/account_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index d1badbdf3b..58b8adbd32 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -66,7 +66,7 @@ class AccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = await self.store.get_global_account_data_by_type_for_user( - account_data_type, user_id + user_id, account_data_type ) if event is None: -- cgit 1.5.1 From b784299cbc121d27d7dadd0a4a96f4657244a4e9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Jan 2022 05:31:31 -0500 Subject: Do not try to serialize raw aggregations dict. (#11791) --- changelog.d/11612.bugfix | 1 + changelog.d/11612.misc | 1 - changelog.d/11791.bugfix | 1 + synapse/events/utils.py | 4 +- synapse/rest/admin/rooms.py | 13 ++--- synapse/rest/client/room.py | 11 ++-- tests/rest/client/test_relations.py | 108 ++++++++++++++++++++++++------------ 7 files changed, 85 insertions(+), 54 deletions(-) create mode 100644 changelog.d/11612.bugfix delete mode 100644 changelog.d/11612.misc create mode 100644 changelog.d/11791.bugfix (limited to 'synapse/rest/client') diff --git a/changelog.d/11612.bugfix b/changelog.d/11612.bugfix new file mode 100644 index 0000000000..842f6892fd --- /dev/null +++ b/changelog.d/11612.bugfix @@ -0,0 +1 @@ +Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/changelog.d/11612.misc b/changelog.d/11612.misc deleted file mode 100644 index 2d886169c5..0000000000 --- a/changelog.d/11612.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid database access in the JSON serialization process. diff --git a/changelog.d/11791.bugfix b/changelog.d/11791.bugfix new file mode 100644 index 0000000000..842f6892fd --- /dev/null +++ b/changelog.d/11791.bugfix @@ -0,0 +1 @@ +Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index de0e0c1731..918adeecf8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -402,7 +402,7 @@ class EventClientSerializer: if bundle_aggregations: event_aggregations = bundle_aggregations.get(event.event_id) if event_aggregations: - self._injected_bundled_aggregations( + self._inject_bundled_aggregations( event, time_now, bundle_aggregations[event.event_id], @@ -411,7 +411,7 @@ class EventClientSerializer: return serialized_event - def _injected_bundled_aggregations( + def _inject_bundled_aggregations( self, event: EventBase, time_now: int, diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 2e714ac87b..efe25fe7eb 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -744,20 +744,15 @@ class RoomEventContextServlet(RestServlet): ) time_now = self.clock.time_msec() + aggregations = results.pop("aggregations", None) results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_before"], time_now, bundle_aggregations=aggregations ) results["event"] = self._event_serializer.serialize_event( - results["event"], - time_now, - bundle_aggregations=results["aggregations"], + results["event"], time_now, bundle_aggregations=aggregations ) results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_after"], time_now, bundle_aggregations=aggregations ) results["state"] = self._event_serializer.serialize_events( results["state"], time_now diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 31fd329a38..90bb9142a0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -714,18 +714,15 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() + aggregations = results.pop("aggregations", None) results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_before"], time_now, bundle_aggregations=aggregations ) results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=results["aggregations"] + results["event"], time_now, bundle_aggregations=aggregations ) results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_after"], time_now, bundle_aggregations=aggregations ) results["state"] = self._event_serializer.serialize_events( results["state"], time_now diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 4b20ab0e3e..c9b220e73d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -21,6 +21,7 @@ from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.types import JsonDict from tests import unittest from tests.server import FakeChannel @@ -454,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_bundled_aggregations(self): - """Test that annotations, references, and threads get correctly bundled.""" + """ + Test that annotations, references, and threads get correctly bundled. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + + See test_edit for a similar test for edits. + """ # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -482,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - def assert_bundle(actual): + def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") # Ensure the fields are as expected. self.assertCountEqual( - actual.keys(), + relations_dict.keys(), ( RelationTypes.ANNOTATION, RelationTypes.REFERENCE, @@ -503,20 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"type": "m.reaction", "key": "b", "count": 1}, ] }, - actual[RelationTypes.ANNOTATION], + relations_dict[RelationTypes.ANNOTATION], ) self.assertEquals( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - actual[RelationTypes.REFERENCE], + relations_dict[RelationTypes.REFERENCE], ) self.assertEquals( 2, - actual[RelationTypes.THREAD].get("count"), + relations_dict[RelationTypes.THREAD].get("count"), ) self.assertTrue( - actual[RelationTypes.THREAD].get("current_user_participated") + relations_dict[RelationTypes.THREAD].get("current_user_participated") ) # The latest thread event has some fields that don't matter. self.assert_dict( @@ -533,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): "type": "m.room.test", "user_id": self.user_id, }, - actual[RelationTypes.THREAD].get("latest_event"), + relations_dict[RelationTypes.THREAD].get("latest_event"), ) - def _find_and_assert_event(events): - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - break - else: - raise AssertionError(f"Event {self.parent_id} not found in chunk") - assert_bundle(event["unsigned"].get("m.relations")) - # Request the event directly. channel = self.make_request( "GET", @@ -554,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["unsigned"].get("m.relations")) + assert_bundle(channel.json_body) # Request the room messages. channel = self.make_request( @@ -563,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - _find_and_assert_event(channel.json_body["chunk"]) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. channel = self.make_request( @@ -572,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - _find_and_assert_event(room_timeline["events"]) - - # Note that /relations is tested separately in test_aggregation_get_event_for_thread - # since it needs different data configured. + self._find_event_in_chunk(room_timeline["events"]) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included @@ -777,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + assert_bundle(channel.json_body) - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + # Request sync, but limit the timeline so it becomes limited (and includes + # bundled aggregations). + filter = urllib.parse.quote_plus( + '{"room": {"timeline": {"limit": 2}}}'.encode() + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token ) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_multi_edit(self): """Test that multiple edits, including attempts by people who @@ -1102,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["chunk"], []) + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + def _send_relation( self, relation_type: str, -- cgit 1.5.1 From 95b3f952fa43e51feae166fa1678761c5e32d900 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 26 Jan 2022 12:02:54 +0000 Subject: Add a config flag to inhibit `M_USER_IN_USE` during registration (#11743) This is mostly motivated by the tchap use case, where usernames are automatically generated from the user's email address (in a way that allows figuring out the email address from the username). Therefore, it's an issue if we respond to requests on /register and /register/available with M_USER_IN_USE, because it can potentially leak email addresses (which include the user's real name and place of work). This commit adds a flag to inhibit the M_USER_IN_USE errors that are raised both by /register/available, and when providing a username early into the registration process. This error will still be raised if the user completes the registration process but the username conflicts. This is particularly useful when using modules (https://github.com/matrix-org/synapse/pull/11790 adds a module callback to set the username of users at registration) or SSO, since they can ensure the username is unique. More context is available in the PR that introduced this behaviour to synapse-dinsic: matrix-org/synapse-dinsic#48 - as well as the issue in the matrix-dinsic repo: matrix-org/matrix-dinsic#476 --- changelog.d/11743.feature | 1 + docs/sample_config.yaml | 10 ++++++++++ synapse/config/registration.py | 12 +++++++++++ synapse/handlers/register.py | 26 +++++++++++++----------- synapse/rest/client/register.py | 11 ++++++++++ tests/rest/client/test_register.py | 41 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 89 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11743.feature (limited to 'synapse/rest/client') diff --git a/changelog.d/11743.feature b/changelog.d/11743.feature new file mode 100644 index 0000000000..9809f48b96 --- /dev/null +++ b/changelog.d/11743.feature @@ -0,0 +1 @@ +Add a config flag to inhibit M_USER_IN_USE during registration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1b86d0295d..b38e6d6c88 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1428,6 +1428,16 @@ account_threepid_delegates: # #auto_join_rooms_for_guests: false +# Whether to inhibit errors raised when registering a new account if the user ID +# already exists. If turned on, that requests to /register/available will always +# show a user ID as available, and Synapse won't raise an error when starting +# a registration with a user ID that already exists. However, Synapse will still +# raise an error if the registration completes and the username conflicts. +# +# Defaults to false. +# +#inhibit_user_in_use_error: true + ## Metrics ### diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7a059c6dec..ea9b50fe97 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -190,6 +190,8 @@ class RegistrationConfig(Config): # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") + self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False) + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -446,6 +448,16 @@ class RegistrationConfig(Config): # Defaults to true. # #auto_join_rooms_for_guests: false + + # Whether to inhibit errors raised when registering a new account if the user ID + # already exists. If turned on, that requests to /register/available will always + # show a user ID as available, and Synapse won't raise an error when starting + # a registration with a user ID that already exists. However, Synapse will still + # raise an error if the registration completes and the username conflicts. + # + # Defaults to false. + # + #inhibit_user_in_use_error: true """ % locals() ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f08a516a75..a719d5eef3 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -132,6 +132,7 @@ class RegistrationHandler: localpart: str, guest_access_token: Optional[str] = None, assigned_user_id: Optional[str] = None, + inhibit_user_in_use_error: bool = False, ) -> None: if types.contains_invalid_mxid_characters(localpart): raise SynapseError( @@ -171,21 +172,22 @@ class RegistrationHandler: users = await self.store.get_users_by_id_case_insensitive(user_id) if users: - if not guest_access_token: + if not inhibit_user_in_use_error and not guest_access_token: raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - user_data = await self.auth.get_user_by_access_token(guest_access_token) - if ( - not user_data.is_guest - or UserID.from_string(user_data.user_id).localpart != localpart - ): - raise AuthError( - 403, - "Cannot register taken user ID without valid guest " - "credentials for that user.", - errcode=Codes.FORBIDDEN, - ) + if guest_access_token: + user_data = await self.auth.get_user_by_access_token(guest_access_token) + if ( + not user_data.is_guest + or UserID.from_string(user_data.user_id).localpart != localpart + ): + raise AuthError( + 403, + "Cannot register taken user ID without valid guest " + "credentials for that user.", + errcode=Codes.FORBIDDEN, + ) if guest_access_token is None: try: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 8b56c76aed..c59dae7c03 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -339,12 +339,19 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) + self.inhibit_user_in_use_error = ( + hs.config.registration.inhibit_user_in_use_error + ) + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) + if self.inhibit_user_in_use_error: + return 200, {"available": True} + ip = request.getClientIP() with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -422,6 +429,9 @@ class RegisterRestServlet(RestServlet): self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None ) + self._inhibit_user_in_use_error = ( + hs.config.registration.inhibit_user_in_use_error + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -564,6 +574,7 @@ class RegisterRestServlet(RestServlet): desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, + inhibit_user_in_use_error=self._inhibit_user_in_use_error, ) # Check if the user-interactive authentication flows are complete, if diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 6e7c0f11df..407dd32a73 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -726,6 +726,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) + @override_config( + { + "inhibit_user_in_use_error": True, + } + ) + def test_inhibit_user_in_use_error(self): + """Tests that the 'inhibit_user_in_use_error' configuration flag behaves + correctly. + """ + username = "arthur" + + # Manually register the user, so we know the test isn't passing because of a lack + # of clashing. + reg_handler = self.hs.get_registration_handler() + self.get_success(reg_handler.register_user(username)) + + # Check that /available correctly ignores the username provided despite the + # username being already registered. + channel = self.make_request("GET", "register/available?username=" + username) + self.assertEquals(200, channel.code, channel.result) + + # Test that when starting a UIA registration flow the request doesn't fail because + # of a conflicting username + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "foo"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Test that finishing the registration fails because of a conflicting username. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) + class AccountValidityTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 2897fb6b4fb8bdaea0e919233d5ccaf5dea12742 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Jan 2022 08:27:04 -0500 Subject: Improvements to bundling aggregations. (#11815) This is some odds and ends found during the review of #11791 and while continuing to work in this code: * Return attrs classes instead of dictionaries from some methods to improve type safety. * Call `get_bundled_aggregations` fewer times. * Adds a missing assertion in the tests. * Do not return empty bundled aggregations for an event (preferring to not include the bundle at all, as the docstring states). --- changelog.d/11815.misc | 1 + synapse/events/utils.py | 57 ++++++++++++++------- synapse/handlers/room.py | 77 +++++++++++++++-------------- synapse/handlers/search.py | 45 ++++++++--------- synapse/handlers/sync.py | 3 +- synapse/push/mailer.py | 2 +- synapse/rest/admin/rooms.py | 39 +++++++++------ synapse/rest/client/room.py | 39 +++++++++------ synapse/rest/client/sync.py | 3 +- synapse/storage/databases/main/relations.py | 61 ++++++++++++++--------- synapse/storage/databases/main/stream.py | 22 ++++++--- tests/rest/client/test_relations.py | 2 +- 12 files changed, 212 insertions(+), 139 deletions(-) create mode 100644 changelog.d/11815.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/11815.misc b/changelog.d/11815.misc new file mode 100644 index 0000000000..83aa6d6eb0 --- /dev/null +++ b/changelog.d/11815.misc @@ -0,0 +1 @@ +Improve type safety of bundled aggregations code. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 918adeecf8..243696b357 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,7 +14,17 @@ # limitations under the License. import collections.abc import re -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, +) from frozendict import frozendict @@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze from . import EventBase +if TYPE_CHECKING: + from synapse.storage.databases.main.relations import BundledAggregations + + # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? JsonDict: """Serializes a single event. @@ -415,7 +429,7 @@ class EventClientSerializer: self, event: EventBase, time_now: int, - aggregations: JsonDict, + aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. @@ -427,13 +441,18 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ - # Make a copy in-case the object is cached. - aggregations = aggregations.copy() + serialized_aggregations = {} + + if aggregations.annotations: + serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations + + if aggregations.references: + serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references - if RelationTypes.REPLACE in aggregations: + if aggregations.replace: # If there is an edit replace the content, preserving existing # relations. - edit = aggregations[RelationTypes.REPLACE] + edit = aggregations.replace # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -451,24 +470,28 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - aggregations[RelationTypes.REPLACE] = { + serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, } # If this event is the start of a thread, include a summary of the replies. - if RelationTypes.THREAD in aggregations: - # Serialize the latest thread event. - latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] - - # Don't bundle aggregations as this could recurse forever. - aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=None - ) + if aggregations.thread: + serialized_aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": self.serialize_event( + aggregations.thread.latest_event, time_now, bundle_aggregations=None + ), + "count": aggregations.thread.count, + "current_user_participated": aggregations.thread.current_user_participated, + } # Include the bundled aggregations in the event. - serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + if serialized_aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + serialized_aggregations + ) def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f963078e59..1420d67729 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -30,6 +30,7 @@ from typing import ( Tuple, ) +import attr from typing_extensions import TypedDict from synapse.api.constants import ( @@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -90,6 +92,17 @@ id_server_scheme = "https://" FIVE_MINUTES_IN_MS = 5 * 60 * 1000 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventContext: + events_before: List[EventBase] + event: EventBase + events_after: List[EventBase] + state: List[EventBase] + aggregations: Dict[str, BundledAggregations] + start: str + end: str + + class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1119,7 +1132,7 @@ class RoomContextHandler: limit: int, event_filter: Optional[Filter], use_admin_priviledge: bool = False, - ) -> Optional[JsonDict]: + ) -> Optional[EventContext]: """Retrieves events, pagination tokens and state around a given event in a room. @@ -1167,38 +1180,28 @@ class RoomContextHandler: results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) + events_before = results.events_before + events_after = results.events_after if event_filter: - results["events_before"] = await event_filter.filter( - results["events_before"] - ) - results["events_after"] = await event_filter.filter(results["events_after"]) + events_before = await event_filter.filter(events_before) + events_after = await event_filter.filter(events_after) - results["events_before"] = await filter_evts(results["events_before"]) - results["events_after"] = await filter_evts(results["events_after"]) + events_before = await filter_evts(events_before) + events_after = await filter_evts(events_after) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. - results["event"] = filtered[0] + event = filtered[0] # Fetch the aggregations. aggregations = await self.store.get_bundled_aggregations( - [results["event"]], user.to_string() + itertools.chain(events_before, (event,), events_after), + user.to_string(), ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_before"], user.to_string() - ) - ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_after"], user.to_string() - ) - ) - results["aggregations"] = aggregations - if results["events_after"]: - last_event_id = results["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event_id @@ -1206,9 +1209,9 @@ class RoomContextHandler: state_filter = StateFilter.from_lazy_load_member_list( ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], + events_before, + (event,), + events_after, ) ) else: @@ -1226,21 +1229,23 @@ class RoomContextHandler: if event_filter: state_events = await event_filter.filter(state_events) - results["state"] = await filter_evts(state_events) - # We use a dummy token here as we only care about the room portion of # the token, which we replace. token = StreamToken.START - results["start"] = await token.copy_and_replace( - "room_key", results["start"] - ).to_string(self.store) - - results["end"] = await token.copy_and_replace( - "room_key", results["end"] - ).to_string(self.store) - - return results + return EventContext( + events_before=events_before, + event=event, + events_after=events_after, + state=await filter_evts(state_events), + aggregations=aggregations, + start=await token.copy_and_replace("room_key", results.start).to_string( + self.store + ), + end=await token.copy_and_replace("room_key", results.end).to_string( + self.store + ), + ) class TimestampLookupHandler: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0b153a6822..02bb5ae72f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -361,36 +361,37 @@ class SearchHandler: logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), - len(res["events_after"]), + len(res.events_before), + len(res.events_after), ) - res["events_before"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_before"] + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before ) - res["events_after"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_after"] + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after ) - res["start"] = await now_token.copy_and_replace( - "room_key", res["start"] - ).to_string(self.store) - - res["end"] = await now_token.copy_and_replace( - "room_key", res["end"] - ).to_string(self.store) + context = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace( + "room_key", res.end + ).to_string(self.store), + } if include_profile: senders = { ev.sender - for ev in itertools.chain( - res["events_before"], [event], res["events_after"] - ) + for ev in itertools.chain(events_before, [event], events_after) } - if res["events_after"]: - last_event_id = res["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event.event_id @@ -402,7 +403,7 @@ class SearchHandler: last_event_id, state_filter ) - res["profile_info"] = { + context["profile_info"] = { s.state_key: { "displayname": s.content.get("displayname", None), "avatar_url": s.content.get("avatar_url", None), @@ -411,7 +412,7 @@ class SearchHandler: if s.type == EventTypes.Member and s.state_key in senders } - contexts[event.event_id] = res + contexts[event.event_id] = context else: contexts = {} @@ -421,10 +422,10 @@ class SearchHandler: for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now + context["events_before"], time_now # type: ignore[arg-type] ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now + context["events_after"], time_now # type: ignore[arg-type] ) state_results = {} diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7e2a892b63..c72ed7c290 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -100,7 +101,7 @@ class TimelineBatch: limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. - bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None + bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index dadfc57413..3df8452eec 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -455,7 +455,7 @@ class Mailer: } the_events = await filter_events_for_client( - self.storage, user_id, results["events_before"] + self.storage, user_id, results.events_before ) the_events.append(notif_event) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index efe25fe7eb..5b706efbcf 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, @@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet): use_admin_priviledge=True, ) - if not results: + if not event_context: raise SynapseError( HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND ) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return HTTPStatus.OK, results diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90bb9142a0..90355e44b2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, limit, event_filter ) - if not results: + if not event_context: raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index d20ae1421e..f9615da525 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -48,6 +48,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace +from synapse.storage.databases.main.relations import BundledAggregations from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet): def serialize( events: Iterable[EventBase], - aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + aggregations: Optional[Dict[str, BundledAggregations]] = None, ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2cb5d06c13..a9a5dd5f03 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,17 +13,7 @@ # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast import attr from frozendict import frozendict @@ -43,6 +33,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -51,6 +42,30 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + latest_event: EventBase + count: int + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore): async def _get_bundled_aggregation_for_event( self, event: EventBase, user_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[BundledAggregations]: """Generate bundled aggregations for an event. Note that this does not use a cache, but depends on cached methods. @@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore): # The bundled aggregations to include, a mapping of relation type to a # type-specific value. Some types include the direct return type here # while others need more processing during serialization. - aggregations: Dict[str, Any] = {} + aggregations = BundledAggregations() annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations.annotations = annotations.to_dict() references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() + aggregations.references = references.to_dict() edit = None if event.type == EventTypes.Message: edit = await self.get_applicable_edit(event_id, room_id) if edit: - aggregations[RelationTypes.REPLACE] = edit + aggregations.replace = edit # If this event is the start of a thread, include a summary of the replies. if self._msc3440_enabled: @@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore): event_id, room_id, user_id ) if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - "latest_event": latest_thread_event, - "count": thread_count, - "current_user_participated": participated, - } + aggregations.thread = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + current_user_participated=participated, + ) # Store the bundled aggregations in the event metadata for later use. return aggregations @@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore): self, events: Iterable[EventBase], user_id: str, - ) -> Dict[str, Dict[str, Any]]: + ) -> Dict[str, BundledAggregations]: """Generate bundled aggregations for events. Args: @@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore): results = {} for event in events: event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result is not None: + if event_result: results[event.event_id] = event_result return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 319464b1fa..a898f847e7 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -81,6 +81,14 @@ class _EventDictReturn: stream_ordering: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventsAround: + events_before: List[EventBase] + events_after: List[EventBase] + start: RoomStreamToken + end: RoomStreamToken + + def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], @@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): before_limit: int, after_limit: int, event_filter: Optional[Filter] = None, - ) -> dict: + ) -> _EventsAround: """Retrieve events and pagination tokens around a given event in a room. """ @@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): list(results["after"]["event_ids"]), get_prev_content=True ) - return { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } + return _EventsAround( + events_before=events_before, + events_after=events_after, + start=results["before"]["token"], + end=results["after"]["token"], + ) def _get_events_around_txn( self, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c9b220e73d..96ae7790bb 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - self._find_event_in_chunk(room_timeline["events"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included -- cgit 1.5.1 From 2d3bd9aa670eedd299cc03093459929adec41918 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 26 Jan 2022 14:21:13 +0000 Subject: Add a module callback to set username at registration (#11790) This is in the context of mainlining the Tchap fork of Synapse. Currently in Tchap usernames are derived from the user's email address (extracted from the UIA results, more specifically the m.login.email.identity step). This change also exports the check_username method from the registration handler as part of the module API, so that a module can check if the username it's trying to generate is correct and doesn't conflict with an existing one, and fallback gracefully if not. Co-authored-by: David Robertson --- changelog.d/11790.feature | 1 + docs/modules/password_auth_provider_callbacks.md | 62 +++++++++++++++++++ synapse/handlers/auth.py | 58 +++++++++++++++++ synapse/module_api/__init__.py | 22 +++++++ synapse/rest/client/register.py | 12 +++- tests/handlers/test_password_providers.py | 79 +++++++++++++++++++++++- 6 files changed, 231 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11790.feature (limited to 'synapse/rest/client') diff --git a/changelog.d/11790.feature b/changelog.d/11790.feature new file mode 100644 index 0000000000..4a5cc8ec37 --- /dev/null +++ b/changelog.d/11790.feature @@ -0,0 +1 @@ +Add a module callback to set username at registration. diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index e53abf6409..ec8324d292 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -105,6 +105,68 @@ device ID), and the (now deactivated) access token. If multiple modules implement this callback, Synapse runs them all in order. +### `get_username_for_registration` + +_First introduced in Synapse v1.52.0_ + +```python +async def get_username_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a username to set for the user +being registered by returning it as a string, or `None` if it doesn't wish to force a +username for this user. If a username is returned, it will be used as the local part of a +user's full Matrix ID (e.g. it's `alice` in `@alice:example.com`). + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. + +The first dictionary contains the results of the [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +flow followed by the user. Its keys are the identifiers of every step involved in the flow, +associated with either a boolean value indicating whether the step was correctly completed, +or additional information (e.g. email address, phone number...). A list of most existing +identifiers can be found in the [Matrix specification](https://spec.matrix.org/v1.1/client-server-api/#authentication-types). +Here's an example featuring all currently supported keys: + +```python +{ + "m.login.dummy": True, # Dummy authentication + "m.login.terms": True, # User has accepted the terms of service for the homeserver + "m.login.recaptcha": True, # User has completed the recaptcha challenge + "m.login.email.identity": { # User has provided and verified an email address + "medium": "email", + "address": "alice@example.com", + "validated_at": 1642701357084, + }, + "m.login.msisdn": { # User has provided and verified a phone number + "medium": "msisdn", + "address": "33123456789", + "validated_at": 1642701357084, + }, + "org.matrix.msc3231.login.registration_token": "sometoken", # User has registered through the flow described in MSC3231 +} +``` + +The second dictionary contains the parameters provided by the user's client in the request +to `/_matrix/client/v3/register`. See the [Matrix specification](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3register) +for a complete list of these parameters. + +If the module cannot, or does not wish to, generate a username for this user, it must +return `None`. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback return `None`, +the username provided by the user is used, if any (otherwise one is automatically +generated). + + ## Example The example module below implements authentication checkers for two different login types: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bd1a322563..e32c93e234 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[ Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] ], ] +GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] class PasswordAuthProvider: @@ -2072,6 +2076,9 @@ class PasswordAuthProvider: # lists of callbacks self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + self.get_username_for_registration_callbacks: List[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = [] # Mapping from login type to login parameters self._supported_login_types: Dict[str, Iterable[str]] = {} @@ -2086,6 +2093,9 @@ class PasswordAuthProvider: auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, + get_username_for_registration: Optional[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2130,6 +2140,11 @@ class PasswordAuthProvider: # Add the new method to the list of auth_checker_callbacks for this login type self.auth_checker_callbacks.setdefault(login_type, []).append(callback) + if get_username_for_registration is not None: + self.get_username_for_registration_callbacks.append( + get_username_for_registration, + ) + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -2285,3 +2300,46 @@ class PasswordAuthProvider: except Exception as e: logger.warning("Failed to run module API callback %s: %s", callback, e) continue + + async def get_username_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the username to use when registering the user, using the credentials + and parameters provided during the UIA flow. + + Stops at the first callback that returns a string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + The localpart to use when registering this user, or None if no module + returned a localpart. + """ + for callback in self.get_username_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_username_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_username_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 662e60bc33..788b2e47d5 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -71,6 +71,7 @@ from synapse.handlers.account_validity import ( from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_USERNAME_FOR_REGISTRATION_CALLBACK, ON_LOGGED_OUT_CALLBACK, AuthHandler, ) @@ -177,6 +178,7 @@ class ModuleApi: self._presence_stream = hs.get_event_sources().sources.presence self._state = hs.get_state_handler() self._clock: Clock = hs.get_clock() + self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self.custom_template_dir = hs.config.server.custom_template_directory @@ -310,6 +312,9 @@ class ModuleApi: auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, + get_username_for_registration: Optional[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -319,6 +324,7 @@ class ModuleApi: check_3pid_auth=check_3pid_auth, on_logged_out=on_logged_out, auth_checkers=auth_checkers, + get_username_for_registration=get_username_for_registration, ) def register_background_update_controller_callbacks( @@ -1202,6 +1208,22 @@ class ModuleApi: """ return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs) + async def check_username(self, username: str) -> None: + """Checks if the provided username uses the grammar defined in the Matrix + specification, and is already being used by an existing user. + + Added in Synapse v1.52.0. + + Args: + username: The username to check. This is the local part of the user's full + Matrix user ID, i.e. it's "alice" if the full user ID is "@alice:foo.com". + + Raises: + SynapseError with the errcode "M_USER_IN_USE" if the username is already in + use. + """ + await self._registration_handler.check_username(username) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index c59dae7c03..e3492f9f93 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -425,6 +425,7 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self.password_auth_provider = hs.get_password_auth_provider() self._registration_enabled = self.hs.config.registration.enable_registration self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None @@ -638,7 +639,16 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = params.get("username", None) + desired_username = await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) + ) + + if desired_username is None: + desired_username = params.get("username", None) + guest_access_token = params.get("guest_access_token", None) if desired_username is not None: diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 2add72b28a..94809cb8be 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,10 +20,11 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.api.constants import LoginType from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login, logout -from synapse.types import JsonDict +from synapse.rest.client import devices, login, logout, register +from synapse.types import JsonDict, UserID from tests import unittest from tests.server import FakeChannel @@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): login.register_servlets, devices.register_servlets, logout.register_servlets, + register.register_servlets, ] def setUp(self): @@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): on_logged_out.assert_called_once() self.assertTrue(self.called) + def test_username(self): + """Tests that the get_username_for_registration callback can define the username + of a user when registering. + """ + self._setup_get_username_for_registration() + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + mxid = channel.json_body["user_id"] + self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + def test_username_uia(self): + """Tests that the get_username_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_username_for_registration() + + # Initiate the UIA flow. + username = "rin" + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + mxid = channel.json_body["user_id"] + self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + + def _setup_get_username_for_registration(self) -> Mock: + """Registers a get_username_for_registration callback that appends "-foo" to the + username the client is trying to register. + """ + + async def get_username_for_registration(uia_results, params): + self.assertIn(LoginType.DUMMY, uia_results) + username = params["username"] + return username + "-foo" + + m = Mock(side_effect=get_username_for_registration) + + password_auth_provider = self.hs.get_password_auth_provider() + password_auth_provider.get_username_for_registration_callbacks.append(m) + + return m + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1