From 94b620a5edd6b5bc55c8aad6e00a11cc6bf210fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Sep 2021 06:44:15 -0400 Subject: Use direct references for configuration variables (part 6). (#10916) --- synapse/handlers/room_member.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'synapse/handlers/room_member.py') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1a56c82fbd..02103f6c9a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -90,7 +90,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid self._enable_lookup = hs.config.enable_3pid_lookup - self.allow_per_room_profiles = self.config.allow_per_room_profiles + self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( store=self.store, @@ -617,7 +617,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: logger.info( "Blocking invite: user is not admin and non-admin " "invites disabled" @@ -1222,7 +1222,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): Raises: ShadowBanError if the requester has been shadow-banned. """ - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( @@ -1420,7 +1420,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Returns: bool of whether the complexity is too great, or None if unable to be fetched """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.federation_handler.get_room_complexity( remote_room_hosts, room_id ) @@ -1436,7 +1436,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Args: room_id: The room ID to check for complexity. """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity @@ -1472,7 +1472,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if too_complex is True: raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) @@ -1507,7 +1507,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) -- cgit 1.5.1 From d1bf5f7c9d669fcf60aadc2c6527447adef2c43c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Sep 2021 11:13:59 -0400 Subject: Strip "join_authorised_via_users_server" from join events which do not need it. (#10933) This fixes a "Event not signed by authorising server" error when transition room member from join -> join, e.g. when updating a display name or avatar URL for restricted rooms. --- changelog.d/10933.bugfix | 1 + synapse/api/constants.py | 3 +++ synapse/event_auth.py | 12 +++++++----- synapse/events/utils.py | 2 +- synapse/federation/federation_base.py | 6 +++--- synapse/federation/federation_client.py | 6 +++--- synapse/federation/federation_server.py | 6 +++--- synapse/handlers/federation.py | 9 +++++++-- synapse/handlers/room_member.py | 10 +++++++++- tests/events/test_utils.py | 7 ++++--- tests/test_event_auth.py | 9 +++++---- 11 files changed, 46 insertions(+), 25 deletions(-) create mode 100644 changelog.d/10933.bugfix (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/10933.bugfix b/changelog.d/10933.bugfix new file mode 100644 index 0000000000..e0694fea22 --- /dev/null +++ b/changelog.d/10933.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.40.0 where changing a user's display name or avatar in a restricted room would cause an authentication error. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 39fd9954d5..a31f037748 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -217,6 +217,9 @@ class EventContentFields: # For "marker" events MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" + # The authorising user for joining a restricted room. + AUTHORISING_USER = "join_authorised_via_users_server" + class RoomTypes: """Understood values of the room_type field of m.room.create events.""" diff --git a/synapse/event_auth.py b/synapse/event_auth.py index eef354de6e..7a1adc2750 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -102,11 +102,11 @@ def validate_event_for_room_version( room_version_obj.msc3083_join_rules and event.type == EventTypes.Member and event.membership == Membership.JOIN - and "join_authorised_via_users_server" in event.content + and EventContentFields.AUTHORISING_USER in event.content ) if is_invite_via_allow_rule: authoriser_domain = get_domain_from_id( - event.content["join_authorised_via_users_server"] + event.content[EventContentFields.AUTHORISING_USER] ) if not event.signatures.get(authoriser_domain): raise AuthError(403, "Event not signed by authorising server") @@ -413,7 +413,9 @@ def _is_membership_change_allowed( # Note that if the caller is in the room or invited, then they do # not need to meet the allow rules. if not caller_in_room and not caller_invited: - authorising_user = event.content.get("join_authorised_via_users_server") + authorising_user = event.content.get( + EventContentFields.AUTHORISING_USER + ) if authorising_user is None: raise AuthError(403, "Join event is missing authorising user.") @@ -868,10 +870,10 @@ def auth_types_for_event( auth_types.add(key) if room_version.msc3083_join_rules and membership == Membership.JOIN: - if "join_authorised_via_users_server" in event.content: + if EventContentFields.AUTHORISING_USER in event.content: key = ( EventTypes.Member, - event.content["join_authorised_via_users_server"], + event.content[EventContentFields.AUTHORISING_USER], ) auth_types.add(key) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a13fb0148f..520edbbf61 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -105,7 +105,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: if event_type == EventTypes.Member: add_fields("membership") if room_version.msc3375_redaction_rules: - add_fields("join_authorised_via_users_server") + add_fields(EventContentFields.AUTHORISING_USER) elif event_type == EventTypes.Create: # MSC2176 rules state that create events cannot be redacted. if room_version.msc2176_redaction_rules: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 024e440ff4..0cd424e12a 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from synapse.api.constants import MAX_DEPTH, EventTypes, Membership +from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.crypto.event_signing import check_event_content_hash @@ -184,10 +184,10 @@ async def _check_sigs_on_pdu( room_version.msc3083_join_rules and pdu.type == EventTypes.Member and pdu.membership == Membership.JOIN - and "join_authorised_via_users_server" in pdu.content + and EventContentFields.AUTHORISING_USER in pdu.content ): authorising_server = get_domain_from_id( - pdu.content["join_authorised_via_users_server"] + pdu.content[EventContentFields.AUTHORISING_USER] ) try: await keyring.verify_event_for_server( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 584836c04a..2ab4dec88f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -37,7 +37,7 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( CodeMessageException, Codes, @@ -875,9 +875,9 @@ class FederationClient(FederationBase): # If the join is being authorised via allow rules, we need to send # the /send_join back to the same server that was originally used # with /make_join. - if "join_authorised_via_users_server" in pdu.content: + if EventContentFields.AUTHORISING_USER in pdu.content: destinations = [ - get_domain_from_id(pdu.content["join_authorised_via_users_server"]) + get_domain_from_id(pdu.content[EventContentFields.AUTHORISING_USER]) ] return await self._try_destination_list( diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 83f11d6b88..d8c0b86f23 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -765,11 +765,11 @@ class FederationServer(FederationBase): if ( room_version.msc3083_join_rules and event.membership == Membership.JOIN - and "join_authorised_via_users_server" in event.content + and EventContentFields.AUTHORISING_USER in event.content ): # We can only authorise our own users. authorising_server = get_domain_from_id( - event.content["join_authorised_via_users_server"] + event.content[EventContentFields.AUTHORISING_USER] ) if authorising_server != self.server_name: raise SynapseError( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0a10a5c28a..043ca4a224 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -27,7 +27,12 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RejectedReason, +) from synapse.api.errors import ( AuthError, CodeMessageException, @@ -716,7 +721,7 @@ class FederationHandler(BaseHandler): if include_auth_user_id: event_content[ - "join_authorised_via_users_server" + EventContentFields.AUTHORISING_USER ] = await self._event_auth_handler.get_user_which_could_invite( room_id, state_ids, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 02103f6c9a..29b3e41cc9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -573,6 +573,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): errcode=Codes.BAD_JSON, ) + # The event content should *not* include the authorising user as + # it won't be properly signed. Strip it out since it might come + # back from a client updating a display name / avatar. + # + # This only applies to restricted rooms, but there should be no reason + # for a client to include it. Unconditionally remove it. + content.pop(EventContentFields.AUTHORISING_USER, None) + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -939,7 +947,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # be included in the event content in order to efficiently validate # the event. content[ - "join_authorised_via_users_server" + EventContentFields.AUTHORISING_USER ] = await self.event_auth_handler.get_user_which_could_invite( room_id, current_state_ids, diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 5446fda5e7..1dea09e480 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( @@ -352,7 +353,7 @@ class PruneEventTestCase(unittest.TestCase): "event_id": "$test:domain", "content": { "membership": "join", - "join_authorised_via_users_server": "@user:domain", + EventContentFields.AUTHORISING_USER: "@user:domain", "other_key": "stripped", }, }, @@ -372,7 +373,7 @@ class PruneEventTestCase(unittest.TestCase): "type": "m.room.member", "content": { "membership": "join", - "join_authorised_via_users_server": "@user:domain", + EventContentFields.AUTHORISING_USER: "@user:domain", "other_key": "stripped", }, }, @@ -380,7 +381,7 @@ class PruneEventTestCase(unittest.TestCase): "type": "m.room.member", "content": { "membership": "join", - "join_authorised_via_users_server": "@user:domain", + EventContentFields.AUTHORISING_USER: "@user:domain", }, "signatures": {}, "unsigned": {}, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index e7a7d00883..cf407c51cf 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -16,6 +16,7 @@ import unittest from typing import Optional from synapse import event_auth +from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict @@ -353,7 +354,7 @@ class EventAuthTestCase(unittest.TestCase): authorised_join_event = _join_event( pleb, additional_content={ - "join_authorised_via_users_server": "@creator:example.com" + EventContentFields.AUTHORISING_USER: "@creator:example.com" }, ) event_auth.check_auth_rules_for_event( @@ -376,7 +377,7 @@ class EventAuthTestCase(unittest.TestCase): _join_event( pleb, additional_content={ - "join_authorised_via_users_server": "@inviter:foo.test" + EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), pl_auth_events, @@ -401,7 +402,7 @@ class EventAuthTestCase(unittest.TestCase): _join_event( pleb, additional_content={ - "join_authorised_via_users_server": "@other:example.com" + EventContentFields.AUTHORISING_USER: "@other:example.com" }, ), auth_events, @@ -417,7 +418,7 @@ class EventAuthTestCase(unittest.TestCase): "join", sender=creator, additional_content={ - "join_authorised_via_users_server": "@inviter:foo.test" + EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), auth_events, -- cgit 1.5.1 From a0f48ee89d88fd7b6da8023dbba607a69073152e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 4 Oct 2021 07:18:54 -0400 Subject: Use direct references for configuration variables (part 7). (#10959) --- changelog.d/10959.misc | 1 + synapse/handlers/auth.py | 2 +- synapse/handlers/identity.py | 13 ++++++++++--- synapse/handlers/profile.py | 4 ++-- synapse/handlers/register.py | 9 ++++++--- synapse/handlers/room_member.py | 2 +- synapse/handlers/ui_auth/checkers.py | 14 ++++++++------ synapse/rest/admin/users.py | 4 ++-- synapse/rest/client/account.py | 22 +++++++++++----------- synapse/rest/client/auth.py | 6 ++++-- synapse/rest/client/capabilities.py | 6 +++--- synapse/rest/client/login.py | 6 +++--- synapse/rest/client/register.py | 26 +++++++++++++------------- synapse/rest/well_known.py | 4 ++-- synapse/storage/databases/main/registration.py | 2 +- synapse/util/threepids.py | 4 ++-- tests/config/test_load.py | 6 +++--- tests/handlers/test_profile.py | 4 ++-- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/test_account.py | 4 ++-- tests/rest/client/test_identity.py | 2 +- tests/rest/client/test_register.py | 4 ++-- tests/unittest.py | 2 +- 23 files changed, 83 insertions(+), 68 deletions(-) create mode 100644 changelog.d/10959.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/10959.misc b/changelog.d/10959.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10959.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a8c717efd5..2d0f3d566c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -198,7 +198,7 @@ class AuthHandler(BaseHandler): if inst.is_enabled(): self.checkers[inst.AUTH_TYPE] = inst # type: ignore - self.bcrypt_rounds = hs.config.bcrypt_rounds + self.bcrypt_rounds = hs.config.registration.bcrypt_rounds # we can't use hs.get_module_api() here, because to do so will create an # import loop. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index a0640fcac0..c881475c25 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -573,9 +573,15 @@ class IdentityHandler(BaseHandler): # Try to validate as email if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Remote emails will only be used if a valid identity server is provided. + assert ( + self.hs.config.registration.account_threepid_delegate_email is not None + ) + # Ask our delegated email identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details @@ -587,10 +593,11 @@ class IdentityHandler(BaseHandler): return validation_session # Try to validate as msisdn - if self.hs.config.account_threepid_delegate_msisdn: + if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) return validation_session diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 425c0d4973..2e19706c69 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -178,7 +178,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: + if not by_admin and not self.hs.config.registration.enable_set_displayname: profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( @@ -268,7 +268,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: + if not by_admin and not self.hs.config.registration.enable_set_avatar_url: profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cb4eb0720b..441af7a848 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,8 +116,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() - self.session_lifetime = hs.config.session_lifetime - self.access_token_lifetime = hs.config.access_token_lifetime + self.session_lifetime = hs.config.registration.session_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime init_counters_for_auth_provider("") @@ -343,7 +343,10 @@ class RegistrationHandler(BaseHandler): # If the user does not need to consent at registration, auto-join any # configured rooms. if not self.hs.config.consent.user_consent_at_registration: - if not self.hs.config.auto_join_rooms_for_guests and make_guest: + if ( + not self.hs.config.registration.auto_join_rooms_for_guests + and make_guest + ): logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", user_id, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 29b3e41cc9..c8fb24a20c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -89,7 +89,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid - self._enable_lookup = hs.config.enable_3pid_lookup + self._enable_lookup = hs.config.registration.enable_3pid_lookup self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8f5d465fa1..184730ebe8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker: # msisdns are currently always ThreepidBehaviour.REMOTE if medium == "msisdn": - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "Phone number verification is not enabled on this homeserver" ) threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) elif medium == "email": if ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE ): - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL @@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return bool(self.hs.config.account_threepid_delegate_msisdn) + return bool(self.hs.config.registration.account_threepid_delegate_msisdn) async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict) @@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration_requires_token) + self._enabled = bool(hs.config.registration.registration_requires_token) self.store = hs.get_datastore() def is_enabled(self) -> bool: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 46bfec4623..f20aa65301 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() - if not self.hs.config.registration_shared_secret: + if not self.hs.config.registration.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") body = parse_json_object_from_request(request) @@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet): got_mac = body["mac"] want_mac_builder = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), + key=self.hs.config.registration.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) want_mac_builder.update(nonce.encode("utf8")) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fff133ef10..6b272658fc 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - if not self.config.account_threepid_delegate_msisdn: + if not self.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "This homeserver is not validating phone numbers. Use an identity server " @@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.account_threepid_delegate_msisdn, + self.config.registration.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], body["token"], @@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 282861fae2..c9ad35a3ad 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -49,8 +49,10 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template - self.registration_token_template = hs.config.registration_token_template - self.success_template = hs.config.fallback_success_template + self.registration_token_template = ( + hs.config.registration.registration_token_template + ) + self.success_template = hs.config.registration.fallback_success_template async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index d6b6256413..2a3e24ae7e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3283_enabled: response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.enable_set_displayname + "enabled": self.config.registration.enable_set_displayname } response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.enable_set_avatar_url + "enabled": self.config.registration.enable_set_avatar_url } response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.enable_3pid_changes + "enabled": self.config.registration.enable_3pid_changes } return 200, response diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index fa5c173f4b..d49a647b03 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self.auth = hs.get_auth() @@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: + if hs.config.registration.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a6eb6f6410..bf3cb34146 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._registration_enabled = self.hs.config.registration.enable_registration + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet): async def _do_guest_registration( self, params: JsonDict, address: Optional[str] = None ) -> Tuple[int, JsonDict]: - if not self.hs.config.allow_guest_access: + if not self.hs.config.registration.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( make_guest=True, address=address @@ -849,13 +849,13 @@ def _calculate_registration_flows( """ # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid + require_email = "email" in config.registration.registrations_require_3pid + require_msisdn = "msisdn" in config.registration.registrations_require_3pid show_msisdn = True show_email = True - if config.disable_msisdn_registration: + if config.registration.disable_msisdn_registration: show_msisdn = False require_msisdn = False @@ -909,7 +909,7 @@ def _calculate_registration_flows( flow.insert(0, LoginType.RECAPTCHA) # Prepend registration token to all flows if we're requiring a token - if config.registration_requires_token: + if config.registration.registration_requires_token: for flow in flows: flow.insert(0, LoginType.REGISTRATION_TOKEN) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index c80a3a99aa..7ac01faab4 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -39,9 +39,9 @@ class WellKnownBuilder: result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} - if self._config.default_identity_server: + if self._config.registration.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server + "base_url": self._config.registration.default_identity_server } return result diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7279b0924e..de262fbf5a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): We do this by grandfathering in existing user threepids assuming that they used one of the server configured trusted identity servers. """ - id_servers = set(self.config.trusted_third_party_id_servers) + id_servers = set(self.config.registration.trusted_third_party_id_servers) def _bg_user_threepids_grandfather_txn(txn): sql = """ diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index baa9190a9a..389adf00f6 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: bool: whether the 3PID medium/address is allowed to be added to this HS """ - if hs.config.allowed_local_3pids: - for constraint in hs.config.allowed_local_3pids: + if hs.config.registration.allowed_local_3pids: + for constraint in hs.config.registration.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", address, diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ef6c2beec7..8e49ca26d9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -84,16 +84,16 @@ class ConfigLoadingTestCase(unittest.TestCase): ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( "", ["-c", self.file, "--enable-registration"] ) - self.assertTrue(config.enable_registration) + self.assertTrue(config.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 57cc3e2646..c153018fd8 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False + self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed self.get_success( @@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False + self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed self.get_success( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a285d5a7fe..6ed9e42173 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): If there is no shared secret, registration through this method will be prevented. """ - self.hs.config.registration_shared_secret = None + self.hs.config.registration.registration_shared_secret = None channel = self.make_request("POST", self.url, b"{}") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 2f44547bfb..89d85b0a17 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -664,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -734,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False # Add a threepid self.get_success( diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ca2e8ff8ef..becb4e8dcc 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index af135d57e1..66dcfc9f88 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True + self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False + self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/unittest.py b/tests/unittest.py index 0807467e39..1f803564f6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -560,7 +560,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user. """ - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" # Create the user channel = self.make_request("GET", "/_synapse/admin/v1/register") -- cgit 1.5.1 From 829f2a82b042d944fef3df55faec924502cdf20d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 6 Oct 2021 16:32:16 +0200 Subject: Add a spamchecker callback to allow or deny room joins (#10910) Co-authored-by: Erik Johnston --- changelog.d/10910.feature | 1 + docs/modules/spam_checker_callbacks.md | 15 +++++ synapse/events/spamcheck.py | 24 ++++++++ synapse/handlers/room.py | 2 + synapse/handlers/room_member.py | 31 ++++++++++ tests/rest/client/test_rooms.py | 101 +++++++++++++++++++++++++++++++++ 6 files changed, 174 insertions(+) create mode 100644 changelog.d/10910.feature (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/10910.feature b/changelog.d/10910.feature new file mode 100644 index 0000000000..aee139f8b6 --- /dev/null +++ b/changelog.d/10910.feature @@ -0,0 +1 @@ +Add a spam checker callback to allow or deny room joins. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 7920ac5f8f..92376df993 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -19,6 +19,21 @@ either a `bool` to indicate whether the event must be rejected because of spam, to indicate the event must be rejected because of spam and to give a rejection reason to forward to clients. +### `user_may_join_room` + +```python +async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool +``` + +Called when a user is trying to join a room. The module must return a `bool` to indicate +whether the user can join the room. The user is represented by their Matrix user ID (e.g. +`@alice:example.com`) and the room is represented by its Matrix ID (e.g. +`!room:example.com`). The module is also given a boolean to indicate whether the user +currently has a pending invite in the room. + +This callback isn't called if the join is performed by a server administrator, or in the +context of a room creation. + ### `user_may_invite` ```python diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index c389f70b8d..ec8863e397 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -44,6 +44,7 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ["synapse.events.EventBase"], Awaitable[Union[bool, str]], ] +USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[ @@ -165,6 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): class SpamChecker: def __init__(self): self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] + self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] self._user_may_create_room_with_invites_callbacks: List[ @@ -187,6 +189,7 @@ class SpamChecker: def register_callbacks( self, check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, + user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, user_may_create_room_with_invites: Optional[ @@ -206,6 +209,9 @@ class SpamChecker: if check_event_for_spam is not None: self._check_event_for_spam_callbacks.append(check_event_for_spam) + if user_may_join_room is not None: + self._user_may_join_room_callbacks.append(user_may_join_room) + if user_may_invite is not None: self._user_may_invite_callbacks.append(user_may_invite) @@ -259,6 +265,24 @@ class SpamChecker: return False + async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool): + """Checks if a given users is allowed to join a room. + Not called when a user creates a room. + + Args: + userid: The ID of the user wanting to join the room + room_id: The ID of the room the user wants to join + is_invited: Whether the user is invited into the room + + Returns: + bool: Whether the user may join the room + """ + for callback in self._user_may_join_room_callbacks: + if await callback(user_id, room_id, is_invited) is False: + return False + + return True + async def user_may_invite( self, inviter_userid: str, invitee_userid: str, room_id: str ) -> bool: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 873e08258e..d40dbd761d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -860,6 +860,7 @@ class RoomCreationHandler(BaseHandler): "invite", ratelimit=False, content=content, + new_room=True, ) for invite_3pid in invite_3pid_list: @@ -962,6 +963,7 @@ class RoomCreationHandler(BaseHandler): "join", ratelimit=ratelimit, content=creator_join_profile, + new_room=True, ) # We treat the power levels override specially as this needs to be one diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c8fb24a20c..0b79dbcf8d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -434,6 +434,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, outlier: bool = False, prev_event_ids: Optional[List[str]] = None, @@ -451,6 +452,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Information from a 3PID invite. ratelimit: Whether to rate limit the request. content: The content of the created event. + new_room: Whether the membership update is happening in the context of a room + creation. require_consent: Whether consent is required. outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as @@ -485,6 +488,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed=third_party_signed, ratelimit=ratelimit, content=content, + new_room=new_room, require_consent=require_consent, outlier=outlier, prev_event_ids=prev_event_ids, @@ -504,6 +508,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, outlier: bool = False, prev_event_ids: Optional[List[str]] = None, @@ -523,6 +528,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: ratelimit: content: + new_room: Whether the membership update is happening in the context of a room + creation. require_consent: outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as @@ -726,6 +733,30 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") + # Figure out whether the user is a server admin to determine whether they + # should be able to bypass the spam checker. + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to join rooms + bypass_spam_checker = True + + else: + bypass_spam_checker = await self.auth.is_server_admin(requester.user) + + inviter = await self._get_inviter(target.to_string(), room_id) + if ( + not bypass_spam_checker + # We assume that if the spam checker allowed the user to create + # a room then they're allowed to join it. + and not new_room + and not await self.spam_checker.user_may_join_room( + target.to_string(), room_id, is_invited=inviter is not None + ) + ): + raise SynapseError(403, "Not allowed to join this room") + # Check if a remote join should be performed. remote_join, remote_room_hosts = await self._should_perform_remote_join( target.to_string(), room_id, remote_room_hosts, content, is_host_in_room diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 30bdaa9c27..a41ec6a98f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase): # Check that do_3pid_invite wasn't called this time. self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) + def test_spam_checker_may_join_room(self): + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + """ + + async def user_may_join_room( + mxid: str, + room_id: str, + is_invite: bool, + ) -> bool: + return False + + join_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEquals(channel.code, 200, channel.json_body) + + self.assertEquals(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase): self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) +class RoomJoinTestCase(RoomBase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user1 = self.register_user("thomas", "hackme") + self.tok1 = self.login("thomas", "hackme") + + self.user2 = self.register_user("teresa", "hackme") + self.tok2 = self.login("teresa", "hackme") + + self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + + def test_spam_checker_may_join_room(self): + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value = True + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> bool: + return return_value + + callback_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + return_value = False + self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" -- cgit 1.5.1 From f4b1a9a527273ef71b2f7d970642b7af45462e0f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 6 Oct 2021 10:47:41 -0400 Subject: Require direct references to configuration variables. (#10985) This removes the magic allowing accessing configurable variables directly from the config object. It is now required that a specific configuration class is used (e.g. `config.foo` must be replaced with `config.server.foo`). --- changelog.d/10985.misc | 1 + scripts/synapse_port_db | 4 +- scripts/update_synapse_database | 2 +- synapse/app/_base.py | 2 +- synapse/app/admin_cmd.py | 4 +- synapse/app/homeserver.py | 2 +- synapse/config/_base.py | 64 ++++---------------------- synapse/config/account_validity.py | 2 +- synapse/config/cas.py | 2 +- synapse/config/emailconfig.py | 9 ++-- synapse/config/key.py | 6 ++- synapse/config/oidc.py | 2 +- synapse/config/registration.py | 7 ++- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/server_notices.py | 4 +- synapse/config/sso.py | 6 ++- synapse/handlers/account_validity.py | 8 +--- synapse/handlers/room_member.py | 7 ++- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/handler.py | 7 ++- synapse/rest/client/auth.py | 2 +- synapse/rest/client/push_rule.py | 4 +- synapse/storage/databases/main/push_rule.py | 4 +- synapse/storage/databases/main/registration.py | 4 +- tests/config/test_base.py | 21 +++++---- tests/config/test_cache.py | 50 ++++++++------------ tests/config/test_load.py | 12 +++-- tests/config/test_tls.py | 38 +++++++-------- tests/storage/test_appservice.py | 2 +- tests/storage/test_txn_limit.py | 2 +- 31 files changed, 124 insertions(+), 160 deletions(-) create mode 100644 changelog.d/10985.misc (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/10985.misc b/changelog.d/10985.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10985.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index fa6ac6d93a..a947d9e49e 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -215,7 +215,7 @@ class MockHomeserver: def __init__(self, config): self.clock = Clock(reactor) self.config = config - self.hostname = config.server_name + self.hostname = config.server.server_name self.version_string = "Synapse/" + get_version_string(synapse) def get_clock(self): @@ -583,7 +583,7 @@ class Porter(object): return self.postgres_store = self.build_db_store( - self.hs_config.get_single_database() + self.hs_config.database.get_single_database() ) await self.run_background_updates_on_postgres() diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 26b29b0b45..6c088bad93 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -36,7 +36,7 @@ class MockHomeserver(HomeServer): def __init__(self, config, **kwargs): super(MockHomeserver, self).__init__( - config.server_name, reactor=reactor, config=config, **kwargs + config.server.server_name, reactor=reactor, config=config, **kwargs ) self.version_string = "Synapse/" + get_version_string(synapse) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 749bc1deb9..4a204a5823 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -301,7 +301,7 @@ def refresh_certificate(hs): if not hs.config.server.has_tls_listener(): return - hs.config.read_certificate_from_disk() + hs.config.tls.read_certificate_from_disk() hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config) if hs._listening_services: diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 556bcc124e..13d20af457 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -197,9 +197,9 @@ def start(config_options): # Explicitly disable background processes config.server.update_user_directory = False config.worker.run_background_tasks = False - config.start_pushers = False + config.worker.start_pushers = False config.pusher_shard_config.instances = [] - config.send_federation = False + config.worker.send_federation = False config.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b2d4bbf83..422f03cc04 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["media", "federation", "client"]: - if self.config.media.enable_media_repo: + if self.config.server.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 26152b0924..7c4428a138 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -118,21 +118,6 @@ class Config: "synapse", "res/templates" ) - def __getattr__(self, item: str) -> Any: - """ - Try and fetch a configuration option that does not exist on this class. - - This is so that existing configs that rely on `self.value`, where value - is actually from a different config section, continue to work. - """ - if item in ["generate_config_section", "read_config"]: - raise AttributeError(item) - - if self.root is None: - raise AttributeError(item) - else: - return self.root._get_unclassed_config(self.section, item) - @staticmethod def parse_size(value): if isinstance(value, int): @@ -289,7 +274,9 @@ class Config: env.filters.update( { "format_ts": _format_ts_filter, - "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + "mxc_to_http": _create_mxc_to_http_filter( + self.root.server.public_baseurl + ), } ) @@ -311,8 +298,6 @@ class RootConfig: config_classes = [] def __init__(self): - self._configs = OrderedDict() - for config_class in self.config_classes: if config_class.section is None: raise ValueError("%r requires a section name" % (config_class,)) @@ -321,42 +306,7 @@ class RootConfig: conf = config_class(self) except Exception as e: raise Exception("Failed making %s: %r" % (config_class.section, e)) - self._configs[config_class.section] = conf - - def __getattr__(self, item: str) -> Any: - """ - Redirect lookups on this object either to config objects, or values on - config objects, so that `config.tls.blah` works, as well as legacy uses - of things like `config.server.server_name`. It will first look up the config - section name, and then values on those config classes. - """ - if item in self._configs.keys(): - return self._configs[item] - - return self._get_unclassed_config(None, item) - - def _get_unclassed_config(self, asking_section: Optional[str], item: str): - """ - Fetch a config value from one of the instantiated config classes that - has not been fetched directly. - - Args: - asking_section: If this check is coming from a Config child, which - one? This section will not be asked if it has the value. - item: The configuration value key. - - Raises: - AttributeError if no config classes have the config key. The body - will contain what sections were checked. - """ - for key, val in self._configs.items(): - if key == asking_section: - continue - - if item in dir(val): - return getattr(val, item) - - raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),)) + setattr(self, config_class.section, conf) def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: """ @@ -373,9 +323,11 @@ class RootConfig: """ res = OrderedDict() - for name, config in self._configs.items(): + for config_class in self.config_classes: + config = getattr(self, config_class.section) + if hasattr(config, func_name): - res[name] = getattr(config, func_name)(*args, **kwargs) + res[config_class.section] = getattr(config, func_name)(*args, **kwargs) return res diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index ffaffc4931..b56c2a24df 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -76,7 +76,7 @@ class AccountValidityConfig(Config): ) if self.account_validity_renew_by_email_enabled: - if not self.public_baseurl: + if not self.root.server.public_baseurl: raise ConfigError("Can't send renewal emails without 'public_baseurl'") # Load account validity templates. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 901f4123e1..9b58ecf3d8 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -37,7 +37,7 @@ class CasConfig(Config): # The public baseurl is required because it is used by the redirect # template. - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if not public_baseurl: raise ConfigError("cas_config requires a public_baseurl to be set") diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 936abe6178..8ff59aa2f8 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -19,7 +19,6 @@ import email.utils import logging import os from enum import Enum -from typing import Optional import attr @@ -135,7 +134,7 @@ class EmailConfig(Config): # msisdn is currently always remote while Synapse does not support any method of # sending SMS messages ThreepidBehaviour.REMOTE - if self.account_threepid_delegate_email + if self.root.registration.account_threepid_delegate_email else ThreepidBehaviour.LOCAL ) # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would @@ -144,7 +143,7 @@ class EmailConfig(Config): # identity server in the process. self.using_identity_server_from_trusted_list = False if ( - not self.account_threepid_delegate_email + not self.root.registration.account_threepid_delegate_email and config.get("trust_identity_server_for_password_resets", False) is True ): # Use the first entry in self.trusted_third_party_id_servers instead @@ -156,7 +155,7 @@ class EmailConfig(Config): # trusted_third_party_id_servers does not contain a scheme whereas # account_threepid_delegate_email is expected to. Presume https - self.account_threepid_delegate_email: Optional[str] = ( + self.root.registration.account_threepid_delegate_email = ( "https://" + first_trusted_identity_server ) self.using_identity_server_from_trusted_list = True @@ -335,7 +334,7 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if self.account_validity_renew_by_email_enabled: + if self.root.account_validity.account_validity_renew_by_email_enabled: expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) diff --git a/synapse/config/key.py b/synapse/config/key.py index 94a9063043..015dbb8a67 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -145,11 +145,13 @@ class KeyConfig(Config): # list of TrustedKeyServer objects self.key_servers = list( - _parse_key_servers(key_servers, self.federation_verify_certificates) + _parse_key_servers( + key_servers, self.root.tls.federation_verify_certificates + ) ) self.macaroon_secret_key = config.get( - "macaroon_secret_key", self.registration_shared_secret + "macaroon_secret_key", self.root.registration.registration_shared_secret ) if not self.macaroon_secret_key: diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 7e67fbada1..10f5796330 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -58,7 +58,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7cffdacfa5..a3d2a38c4c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -45,7 +45,10 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: + if ( + self.account_threepid_delegate_msisdn + and not self.root.server.public_baseurl + ): raise ConfigError( "The configuration option `public_baseurl` is required if " "`account_threepid_delegate.msisdn` is set, such that " @@ -85,7 +88,7 @@ class RegistrationConfig(Config): if mxid_localpart: # Convert the localpart to a full mxid. self.auto_join_user_id = UserID( - mxid_localpart, self.server_name + mxid_localpart, self.root.server.server_name ).to_string() if self.autocreate_auto_join_rooms: diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 7481f3bf5f..69906a98d4 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -94,7 +94,7 @@ class ContentRepositoryConfig(Config): # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( - self.enable_media_repo is False + self.root.server.enable_media_repo is False and config.get("worker_app") != "synapse.app.media_repository" ): self.can_load_media_repo = False diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 05e983625d..9c51b6a25a 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -199,7 +199,7 @@ class SAML2Config(Config): """ import saml2 - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py index 48bf3241b6..bde4e879d9 100644 --- a/synapse/config/server_notices.py +++ b/synapse/config/server_notices.py @@ -73,7 +73,9 @@ class ServerNoticesConfig(Config): return mxid_localpart = c["system_mxid_localpart"] - self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string() + self.server_notices_mxid = UserID( + mxid_localpart, self.root.server.server_name + ).to_string() self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None) self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None) # todo: i18n diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 524a7ff3aa..11a9b76aa0 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -103,8 +103,10 @@ class SSOConfig(Config): # the client's. # public_baseurl is an optional setting, so we only add the fallback's URL to the # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + if self.root.server.public_baseurl: + login_fallback_url = ( + self.root.server.public_baseurl + "_matrix/static/client/login" + ) self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 5a5f124ddf..87e415df75 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -67,12 +67,8 @@ class AccountValidityHandler: and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = ( - hs.config.account_validity.account_validity_template_html - ) - self._template_text = ( - hs.config.account_validity.account_validity_template_text - ) + self._template_html = hs.config.email.account_validity_template_html + self._template_text = hs.config.email.account_validity_template_text self._renew_email_subject = ( hs.config.account_validity.account_validity_renew_email_subject ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0b79dbcf8d..c05461bf2a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1499,8 +1499,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - check_complexity = self.hs.config.limit_remote_rooms.enabled - if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = self.hs.config.server.limit_remote_rooms.enabled + if ( + check_complexity + and self.hs.config.server.limit_remote_rooms.admins_can_join + ): check_complexity = not await self.auth.is_server_admin(user) if check_complexity: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 37769ace48..961c17762e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -117,7 +117,7 @@ class ReplicationDataHandler: self._instance_name = hs.get_instance_name() self._typing_handler = hs.get_typing_handler() - self._notify_pushers = hs.config.start_pushers + self._notify_pushers = hs.config.worker.start_pushers self._pusher_pool = hs.get_pusherpool() self._presence_handler = hs.get_presence_handler() diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d64d1dbacd..6aa9318027 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -171,7 +171,10 @@ class ReplicationCommandHandler: if hs.config.worker.worker_app is not None: continue - if stream.NAME == FederationStream.NAME and hs.config.send_federation: + if ( + stream.NAME == FederationStream.NAME + and hs.config.worker.send_federation + ): # We only support federation stream if federation sending # has been disabled on the master. continue @@ -225,7 +228,7 @@ class ReplicationCommandHandler: self._is_master = hs.config.worker.worker_app is None self._federation_sender = None - if self._is_master and not hs.config.send_federation: + if self._is_master and not hs.config.worker.send_federation: self._federation_sender = hs.get_federation_sender() self._server_notices_sender = None diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index c9ad35a3ad..9c15a04338 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -48,7 +48,7 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template - self.terms_template = hs.config.terms_template + self.terms_template = hs.config.consent.terms_template self.registration_token_template = ( hs.config.registration.registration_token_template ) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index ecebc46e8d..6f796d5e50 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -61,7 +61,9 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a7fb8cd848..b81e33964a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -101,7 +101,9 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) @abc.abstractmethod def get_max_push_rules_stream_id(self): diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de262fbf5a..7de4ad7f9b 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1778,7 +1778,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + self._ignore_unknown_session_error = ( + hs.config.server.request_token_inhibit_3pid_errors + ) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") diff --git a/tests/config/test_base.py b/tests/config/test_base.py index baa5313fb3..6a52f862f4 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -14,23 +14,28 @@ import os.path import tempfile +from unittest.mock import Mock from synapse.config import ConfigError +from synapse.config._base import Config from synapse.util.stringutils import random_string from tests import unittest -class BaseConfigTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): - self.hs = hs +class BaseConfigTestCase(unittest.TestCase): + def setUp(self): + # The root object needs a server property with a public_baseurl. + root = Mock() + root.server.public_baseurl = "http://test" + self.config = Config(root) def test_loading_missing_templates(self): # Use a temporary directory that exists on the system, but that isn't likely to # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] + template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], (tmp_dir,)) + self.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [template_filename], (td.name for td in tempdirs), ) @@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [other_template_name], (td.name for td in tempdirs), ) @@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): - self.hs.config.read_templates( + self.config.read_templates( ["some_filename.html"], ("a_nonexistent_directory",) ) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 857d9cd096..f518abdb7a 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -12,39 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.config._base import Config, RootConfig from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase -class FakeServer(Config): - section = "server" - - -class TestConfig(RootConfig): - config_classes = [FakeServer, CacheConfig] - - class CacheConfigTests(TestCase): def setUp(self): # Reset caches before each test - TestConfig().caches.reset() + self.config = CacheConfig() + + def tearDown(self): + self.config.reset() def test_individual_caches_from_environ(self): """ Individual cache factors will be loaded from the environment. """ config = {} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_NOT_CACHE": "BLAH", } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) def test_config_overrides_environ(self): """ @@ -52,15 +45,14 @@ class CacheConfigTests(TestCase): over those in the config. """ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_CACHE_FACTOR_FOO": 1, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( - dict(t.caches.cache_factors), + dict(self.config.cache_factors), {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) @@ -76,8 +68,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"per_cache_factors": {"foo": 3}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config) self.assertEqual(cache.max_size, 300) @@ -88,8 +79,7 @@ class CacheConfigTests(TestCase): there is one. """ config = {"caches": {"per_cache_factors": {"foo": 2}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -106,8 +96,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"global_factor": 4}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(cache.max_size, 400) @@ -118,8 +107,7 @@ class CacheConfigTests(TestCase): is no per-cache factor. """ config = {"caches": {"global_factor": 1.5}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -133,12 +121,11 @@ class CacheConfigTests(TestCase): "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} } } - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -158,11 +145,10 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"event_cache_size": "10k"}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache( - max_size=t.caches.event_cache_size, + max_size=self.config.event_cache_size, apply_cache_factor_from_config=False, ) add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 8e49ca26d9..59635de205 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -49,7 +49,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -60,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase): config1 = HomeServerConfig.load_config("", ["-c", self.file]) config2 = HomeServerConfig.load_config("", ["-c", self.file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) - self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key) + self.assertEqual( + config1.key.macaroon_secret_key, config2.key.macaroon_secret_key + ) + self.assertEqual( + config1.key.macaroon_secret_key, config3.key.macaroon_secret_key + ) def test_disable_registration(self): self.generate_config() diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index b6bc1876b5..9ba5781573 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -42,9 +42,9 @@ class TLSConfigTests(TestCase): """ config = {} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") def test_tls_client_minimum_set(self): """ @@ -52,29 +52,29 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": 1.1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1") config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") # Also test a string version config = {"federation_client_minimum_tls_version": "1"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": "1.2"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") def test_tls_client_minimum_1_point_3_missing(self): """ @@ -91,7 +91,7 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() with self.assertRaises(ConfigError) as e: - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( e.exception.args[0], ( @@ -112,8 +112,8 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.3") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") def test_tls_client_minimum_set_passed_through_1_2(self): """ @@ -121,7 +121,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -137,7 +137,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -159,7 +159,7 @@ class TLSConfigTests(TestCase): } t = TestConfig() e = self.assertRaises( - ConfigError, t.read_config, config, config_dir_path="", data_dir_path="" + ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path="" ) self.assertIn("IDNA domain names", str(e)) @@ -174,7 +174,7 @@ class TLSConfigTests(TestCase): ] } t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cf9748f218..f26d5acf9c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -126,7 +126,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.db_pool = database._db_pool self.engine = database.engine - db_config = hs.config.get_single_database() + db_config = hs.config.database.get_single_database() self.store = TestTransactionStore( database, make_conn(db_config, self.engine, "test"), hs ) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 6ff3ebb137..ace82cbf42 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -22,7 +22,7 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): return self.setup_test_homeserver(db_txn_limit=1000) def test_config(self): - db_config = self.hs.config.get_single_database() + db_config = self.hs.config.database.get_single_database() self.assertEqual(db_config.config["txn_limit"], 1000) def test_select(self): -- cgit 1.5.1 From 4e5162106436f3fddd12561d316d19fd23148800 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 6 Oct 2021 17:18:13 +0200 Subject: Add a spamchecker method to allow or deny 3pid invites (#10894) This is in the context of creating new module callbacks that modules in https://github.com/matrix-org/synapse-dinsic can use, in an effort to reconcile the spam checker API in synapse-dinsic with the one in mainline. Note that a module callback already exists for 3pid invites (https://matrix-org.github.io/synapse/develop/modules/third_party_rules_callbacks.html#check_threepid_can_be_invited) but it doesn't check whether the sender of the invite is allowed to send it. --- changelog.d/10894.feature | 1 + docs/modules/spam_checker_callbacks.md | 35 +++++++++++++++++ synapse/events/spamcheck.py | 35 +++++++++++++++++ synapse/handlers/room_member.py | 12 ++++++ tests/rest/client/test_rooms.py | 70 ++++++++++++++++++++++++++++++++++ 5 files changed, 153 insertions(+) create mode 100644 changelog.d/10894.feature (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/10894.feature b/changelog.d/10894.feature new file mode 100644 index 0000000000..a4f968bed1 --- /dev/null +++ b/changelog.d/10894.feature @@ -0,0 +1 @@ +Add a `user_may_send_3pid_invite` spam checker callback for modules to allow or deny 3PID invites. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 92376df993..787e99074a 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -44,6 +44,41 @@ Called when processing an invitation. The module must return a `bool` indicating the inviter can invite the invitee to the given room. Both inviter and invitee are represented by their Matrix user ID (e.g. `@alice:example.com`). +### `user_may_send_3pid_invite` + +```python +async def user_may_send_3pid_invite( + inviter: str, + medium: str, + address: str, + room_id: str, +) -> bool +``` + +Called when processing an invitation using a third-party identifier (also called a 3PID, +e.g. an email address or a phone number). The module must return a `bool` indicating +whether the inviter can invite the invitee to the given room. + +The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the +invitee is represented by its medium (e.g. "email") and its address +(e.g. `alice@example.com`). See [the Matrix specification](https://matrix.org/docs/spec/appendices#pid-types) +for more information regarding third-party identifiers. + +For example, a call to this callback to send an invitation to the email address +`alice@example.com` would look like this: + +```python +await user_may_send_3pid_invite( + "@bob:example.com", # The inviter's user ID + "email", # The medium of the 3PID to invite + "alice@example.com", # The address of the 3PID to invite + "!some_room:example.com", # The ID of the room to send the invite into +) +``` + +**Note**: If the third-party identifier is already associated with a matrix user ID, +[`user_may_invite`](#user_may_invite) will be used instead. + ### `user_may_create_room` ```python diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index ec8863e397..ae4c8ab257 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -46,6 +46,7 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ] USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]] +USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[ [str, List[str], List[Dict[str, str]]], Awaitable[bool] @@ -168,6 +169,9 @@ class SpamChecker: self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] + self._user_may_send_3pid_invite_callbacks: List[ + USER_MAY_SEND_3PID_INVITE_CALLBACK + ] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] self._user_may_create_room_with_invites_callbacks: List[ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK @@ -191,6 +195,7 @@ class SpamChecker: check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, + user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, user_may_create_room_with_invites: Optional[ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK @@ -215,6 +220,11 @@ class SpamChecker: if user_may_invite is not None: self._user_may_invite_callbacks.append(user_may_invite) + if user_may_send_3pid_invite is not None: + self._user_may_send_3pid_invite_callbacks.append( + user_may_send_3pid_invite, + ) + if user_may_create_room is not None: self._user_may_create_room_callbacks.append(user_may_create_room) @@ -304,6 +314,31 @@ class SpamChecker: return True + async def user_may_send_3pid_invite( + self, inviter_userid: str, medium: str, address: str, room_id: str + ) -> bool: + """Checks if a given user may invite a given threepid into the room + + If this method returns false, the threepid invite will be rejected. + + Note that if the threepid is already associated with a Matrix user ID, Synapse + will call user_may_invite with said user ID instead. + + Args: + inviter_userid: The user ID of the sender of the invitation + medium: The 3PID's medium (e.g. "email") + address: The 3PID's address (e.g. "alice@example.com") + room_id: The room ID + + Returns: + True if the user may send the invite, otherwise False + """ + for callback in self._user_may_send_3pid_invite_callbacks: + if await callback(inviter_userid, medium, address, room_id) is False: + return False + + return True + async def user_may_create_room(self, userid: str) -> bool: """Checks if a given user may create a room diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c05461bf2a..eef337feeb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1299,10 +1299,22 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if invitee: # Note that update_membership with an action of "invite" can raise # a ShadowBanError, but this was done above already. + # We don't check the invite against the spamchecker(s) here (through + # user_may_invite) because we'll do it further down the line anyway (in + # update_membership_locked). _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: + # Check if the spamchecker(s) allow this invite to go through. + if not await self.spam_checker.user_may_send_3pid_invite( + inviter_userid=requester.user.to_string(), + medium=medium, + address=address, + room_id=room_id, + ): + raise SynapseError(403, "Cannot send threepid invite") + stream_id = await self._make_and_store_3pid_invite( requester, id_server, diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index a41ec6a98f..376853fd65 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2531,3 +2531,73 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) + + +class ThreepidInviteTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("thomas", "hackme") + self.tok = self.login("thomas", "hackme") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_threepid_invite_spamcheck(self): + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable(0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + mock = Mock(return_value=make_awaitable(True)) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEquals(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. + mock.return_value = make_awaitable(False) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEquals(channel.code, 403) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() -- cgit 1.5.1 From eb9ddc8c2e807e691fd1820f88f7c0bf43822661 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 8 Oct 2021 07:44:43 -0400 Subject: Remove the deprecated BaseHandler. (#11005) The shared ratelimit function was replaced with a dedicated RequestRatelimiter class (accessible from the HomeServer object). Other properties were copied to each sub-class that inherited from BaseHandler. --- changelog.d/11005.misc | 1 + synapse/api/ratelimiting.py | 86 +++++++++++++++++++++++ synapse/handlers/_base.py | 120 --------------------------------- synapse/handlers/admin.py | 7 +- synapse/handlers/auth.py | 8 +-- synapse/handlers/deactivate_account.py | 6 +- synapse/handlers/device.py | 10 +-- synapse/handlers/directory.py | 9 ++- synapse/handlers/events.py | 12 ++-- synapse/handlers/federation.py | 6 +- synapse/handlers/identity.py | 7 +- synapse/handlers/initial_sync.py | 8 +-- synapse/handlers/message.py | 7 +- synapse/handlers/profile.py | 11 +-- synapse/handlers/read_marker.py | 5 +- synapse/handlers/receipts.py | 6 +- synapse/handlers/register.py | 9 ++- synapse/handlers/room.py | 15 +++-- synapse/handlers/room_list.py | 7 +- synapse/handlers/room_member.py | 8 +-- synapse/handlers/saml.py | 7 +- synapse/handlers/search.py | 9 +-- synapse/handlers/set_password.py | 6 +- synapse/server.py | 11 ++- 24 files changed, 166 insertions(+), 215 deletions(-) create mode 100644 changelog.d/11005.misc delete mode 100644 synapse/handlers/_base.py (limited to 'synapse/handlers/room_member.py') diff --git a/changelog.d/11005.misc b/changelog.d/11005.misc new file mode 100644 index 0000000000..a893591971 --- /dev/null +++ b/changelog.d/11005.misc @@ -0,0 +1 @@ +Remove the deprecated `BaseHandler` object. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index cbdd74025b..e8964097d3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.config.ratelimiting import RateLimitConfig from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -233,3 +234,88 @@ class Ratelimiter: raise LimitExceededError( retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) + + +class RequestRatelimiter: + def __init__( + self, + store: DataStore, + clock: Clock, + rc_message: RateLimitConfig, + rc_admin_redaction: Optional[RateLimitConfig], + ): + self.store = store + self.clock = clock + + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if rc_admin_redaction: + self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( + store=self.store, + clock=self.clock, + rate_hz=rc_admin_redaction.per_second, + burst_count=rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + + async def ratelimit( + self, + requester: Requester, + update: bool = True, + is_admin_redaction: bool = False, + ) -> None: + """Ratelimits requests. + + Args: + requester + update: Whether to record that a request is being processed. + Set to False when doing multiple checks for one request (e.g. + to check up front if we would reject the request), and set to + True for the last call for a given request. + is_admin_redaction: Whether this is a room admin/moderator + redacting an event. If so then we may apply different + ratelimits depending on config. + + Raises: + LimitExceededError if the request should be ratelimited + """ + user_id = requester.user.to_string() + + # The AS user itself is never rate limited. + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service is not None: + return # do not ratelimit app service senders + + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + + # Check if there is a per user override in the DB. + override = await self.store.get_ratelimit_for_user(user_id) + if override: + # If overridden with a null Hz then ratelimiting has been entirely + # disabled for the user + if not override.messages_per_second: + return + + messages_per_second = override.messages_per_second + burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + else: + # Override rate and burst count per-user + await self.request_ratelimiter.ratelimit( + requester, + rate_hz=messages_per_second, + burst_count=burst_count, + update=update, + ) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py deleted file mode 100644 index 0ccef884e7..0000000000 --- a/synapse/handlers/_base.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2014 - 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Optional - -from synapse.api.ratelimiting import Ratelimiter -from synapse.types import Requester - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class BaseHandler: - """ - Common base class for the event handlers. - - Deprecated: new code should not use this. Instead, Handler classes should define the - fields they actually need. The utility methods should either be factored out to - standalone helper functions, or to different Handler classes. - """ - - def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() - self.auth = hs.get_auth() - self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() - self.distributor = hs.get_distributor() - self.clock = hs.get_clock() - self.hs = hs - - # The rate_hz and burst_count are overridden on a per-user basis - self.request_ratelimiter = Ratelimiter( - store=self.store, clock=self.clock, rate_hz=0, burst_count=0 - ) - self._rc_message = self.hs.config.ratelimiting.rc_message - - # Check whether ratelimiting room admin message redaction is enabled - # by the presence of rate limits in the config - if self.hs.config.ratelimiting.rc_admin_redaction: - self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( - store=self.store, - clock=self.clock, - rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second, - burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count, - ) - else: - self.admin_redaction_ratelimiter = None - - self.server_name = hs.hostname - - self.event_builder_factory = hs.get_event_builder_factory() - - async def ratelimit( - self, - requester: Requester, - update: bool = True, - is_admin_redaction: bool = False, - ) -> None: - """Ratelimits requests. - - Args: - requester - update: Whether to record that a request is being processed. - Set to False when doing multiple checks for one request (e.g. - to check up front if we would reject the request), and set to - True for the last call for a given request. - is_admin_redaction: Whether this is a room admin/moderator - redacting an event. If so then we may apply different - ratelimits depending on config. - - Raises: - LimitExceededError if the request should be ratelimited - """ - user_id = requester.user.to_string() - - # The AS user itself is never rate limited. - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service is not None: - return # do not ratelimit app service senders - - messages_per_second = self._rc_message.per_second - burst_count = self._rc_message.burst_count - - # Check if there is a per user override in the DB. - override = await self.store.get_ratelimit_for_user(user_id) - if override: - # If overridden with a null Hz then ratelimiting has been entirely - # disabled for the user - if not override.messages_per_second: - return - - messages_per_second = override.messages_per_second - burst_count = override.burst_count - - if is_admin_redaction and self.admin_redaction_ratelimiter: - # If we have separate config for admin redactions, use a separate - # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) - else: - # Override rate and burst count per-user - await self.request_ratelimiter.ratelimit( - requester, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index bfa7f2c545..a53cd62d3c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -21,18 +21,15 @@ from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class AdminHandler(BaseHandler): +class AdminHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2d0f3d566c..f4612a5b92 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -52,7 +52,6 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.handlers._base import BaseHandler from synapse.handlers.ui_auth import ( INTERACTIVE_AUTH_CHECKERS, UIAuthSessionDataConstants, @@ -186,12 +185,13 @@ class LoginTokenAttributes: auth_provider_id = attr.ib(type=str) -class AuthHandler(BaseHandler): +class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 12bdca7445..e88c3c27ce 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -19,19 +19,17 @@ from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Requester, UserID, create_requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class DeactivateAccountHandler(BaseHandler): +class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 35334725d7..75e6019760 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -40,8 +40,6 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -50,14 +48,16 @@ logger = logging.getLogger(__name__) MAX_DEVICE_DISPLAY_NAME_LEN = 100 -class DeviceWorkerHandler(BaseHandler): +class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.clock = hs.get_clock() self.hs = hs + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() + self.server_name = hs.hostname @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 9078781d5a..14ed7d9879 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -31,18 +31,16 @@ from synapse.appservice import ApplicationService from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class DirectoryHandler(BaseHandler): +class DirectoryHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.auth = hs.get_auth() + self.hs = hs self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() @@ -51,6 +49,7 @@ class DirectoryHandler(BaseHandler): self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() + self.server_name = hs.hostname self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4b3f037072..1f64534a8a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -25,8 +25,6 @@ from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,11 +32,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class EventStreamHandler(BaseHandler): +class EventStreamHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() self.clock = hs.get_clock() + self.hs = hs self.notifier = hs.get_notifier() self.state = hs.get_state_handler() @@ -138,9 +136,9 @@ class EventStreamHandler(BaseHandler): return chunk -class EventHandler(BaseHandler): +class EventHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 043ca4a224..3e341bd287 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -53,7 +53,6 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers._base import BaseHandler from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -78,15 +77,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class FederationHandler(BaseHandler): +class FederationHandler: """Handles general incoming federation requests Incoming events are *not* handled here, for which see FederationEventHandler. """ def __init__(self, hs: "HomeServer"): - super().__init__(hs) - self.hs = hs self.store = hs.get_datastore() @@ -99,6 +96,7 @@ class FederationHandler(BaseHandler): self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self.event_builder_factory = hs.get_event_builder_factory() self._event_auth_handler = hs.get_event_auth_handler() self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self.config = hs.config diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c881475c25..9c319b5383 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -39,8 +39,6 @@ from synapse.util.stringutils import ( valid_id_server_location, ) -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,10 +47,9 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -class IdentityHandler(BaseHandler): +class IdentityHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9ad39a65d8..d4e4556155 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -31,8 +31,6 @@ from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -40,9 +38,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class InitialSyncHandler(BaseHandler): +class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.state_handler = hs.get_state_handler() self.hs = hs self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ccd7827207..4de9f4b828 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -62,8 +62,6 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.server import HomeServer @@ -433,8 +431,7 @@ class EventCreationHandler: self.send_event = ReplicationSendEventRestServlet.make_client(hs) - # This is only used to get at ratelimit function - self.base_handler = BaseHandler(hs) + self.request_ratelimiter = hs.get_request_ratelimiter() # We arbitrarily limit concurrent event creation for a room to 5. # This is to stop us from diverging history *too* much. @@ -1322,7 +1319,7 @@ class EventCreationHandler: original_event and event.sender != original_event.sender ) - await self.base_handler.ratelimit( + await self.request_ratelimiter.ratelimit( requester, is_admin_redaction=is_admin_redaction ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 2e19706c69..e6c3cf585b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -32,8 +32,6 @@ from synapse.types import ( get_domain_from_id, ) -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -43,7 +41,7 @@ MAX_DISPLAYNAME_LEN = 256 MAX_AVATAR_URL_LEN = 1000 -class ProfileHandler(BaseHandler): +class ProfileHandler: """Handles fetching and updating user profile information. ProfileHandler can be instantiated directly on workers and will @@ -54,7 +52,9 @@ class ProfileHandler(BaseHandler): PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.hs = hs self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -62,6 +62,7 @@ class ProfileHandler(BaseHandler): ) self.user_directory_handler = hs.get_user_directory_handler() + self.request_ratelimiter = hs.get_request_ratelimiter() if hs.config.worker.run_background_tasks: self.clock.looping_call( @@ -346,7 +347,7 @@ class ProfileHandler(BaseHandler): if not self.hs.is_mine(target_user): return - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) # Do not actually update the room state for shadow-banned users. if requester.shadow_banned: diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index bd8160e7ed..58593e570e 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -17,17 +17,14 @@ from typing import TYPE_CHECKING from synapse.util.async_helpers import Linearizer -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class ReadMarkerHandler(BaseHandler): +class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.account_data_handler = hs.get_account_data_handler() diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index f21f33ada2..374e961e3b 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.appservice import ApplicationService -from synapse.handlers._base import BaseHandler from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -26,10 +25,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReceiptsHandler(BaseHandler): +class ReceiptsHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 441af7a848..a0e6a01775 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -41,8 +41,6 @@ from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -85,9 +83,10 @@ class LoginDict(TypedDict): refresh_token: Optional[str] -class RegistrationHandler(BaseHandler): +class RegistrationHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -515,7 +514,7 @@ class RegistrationHandler(BaseHandler): # we don't have a local user in the room to craft up an invite with. requires_invite = await self.store.is_host_joined( room_id, - self.server_name, + self._server_name, ) if requires_invite: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d40dbd761d..7072bca1fc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -76,8 +76,6 @@ from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -88,15 +86,18 @@ id_server_scheme = "https://" FIVE_MINUTES_IN_MS = 5 * 60 * 1000 -class RoomCreationHandler(BaseHandler): +class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.hs = hs self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self._event_auth_handler = hs.get_event_auth_handler() self.config = hs.config + self.request_ratelimiter = hs.get_request_ratelimiter() # Room state based off defined presets self._presets_dict: Dict[str, Dict[str, Any]] = { @@ -162,7 +163,7 @@ class RoomCreationHandler(BaseHandler): Raises: ShadowBanError if the requester is shadow-banned. """ - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) user_id = requester.user.to_string() @@ -665,7 +666,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( "room_version", self.config.server.default_room_version.identifier diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index c3d4199ed1..ba7a14d651 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -36,8 +36,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,9 +47,10 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) -class RoomListHandler(BaseHandler): +class RoomListHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index eef337feeb..74e6c7eca6 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -51,8 +51,6 @@ from synapse.types import ( from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -118,9 +116,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, ) - # This is only used to get at the ratelimit function. It's fine there are - # multiple of these as it doesn't store state. - self.base_handler = BaseHandler(hs) + self.request_ratelimiter = hs.get_request_ratelimiter() @abc.abstractmethod async def _remote_join( @@ -1275,7 +1271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - await self.base_handler.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 2fed9f377a..727d75a50c 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -22,7 +22,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest @@ -51,9 +50,11 @@ class Saml2SessionData: ui_auth_session_id: Optional[str] = None -class SamlHandler(BaseHandler): +class SamlHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6d3333ee00..a3ffa26be8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -26,17 +26,18 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class SearchHandler(BaseHandler): +class SearchHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.state_handler = hs.get_state_handler() + self.clock = hs.get_clock() + self.hs = hs self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index a63fac8283..706ad72761 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -17,19 +17,17 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class SetPasswordHandler(BaseHandler): +class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/server.py b/synapse/server.py index 637eb15b78..0783df41d4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -39,7 +39,7 @@ from twisted.web.resource import IResource from synapse.api.auth import Auth from synapse.api.filtering import Filtering -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig @@ -816,3 +816,12 @@ class HomeServer(metaclass=abc.ABCMeta): def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" return self.config.worker.send_federation + + @cache_in_self + def get_request_ratelimiter(self) -> RequestRatelimiter: + return RequestRatelimiter( + self.get_datastore(), + self.get_clock(), + self.config.ratelimiting.rc_message, + self.config.ratelimiting.rc_admin_redaction, + ) -- cgit 1.5.1