From eb2c7e51c460a83b7880eefc66eb9ca6a8adab94 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Sep 2021 09:24:40 -0400 Subject: Clean-up type hints in server config (#10915) By using attrs instead of dicts to store configuration. Also updates some of the attrs classes to use proper type hints and auto_attribs. --- synapse/handlers/pagination.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 08b93b3ec1..a5301ece6f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -92,16 +92,16 @@ class PaginationHandler: if hs.config.worker.run_background_tasks and hs.config.retention_enabled: # Run the purge jobs described in the configuration file. - for job in hs.config.retention_purge_jobs: + for job in hs.config.server.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) self.clock.looping_call( run_as_background_process, - job["interval"], + job.interval, "purge_history_for_rooms_in_range", self.purge_history_for_rooms_in_range, - job["shortest_max_lifetime"], - job["longest_max_lifetime"], + job.shortest_max_lifetime, + job.longest_max_lifetime, ) async def purge_history_for_rooms_in_range( -- cgit 1.5.1 From 2622b28c5cbe38c60c556544aa7502a8684ee60b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 28 Sep 2021 15:25:07 +0100 Subject: Inline `_check_event_auth` for outliers (#10926) * Inline `_check_event_auth` for outliers When we are persisting an outlier, most of `_check_event_auth` is redundant: * `_update_auth_events_and_context_for_auth` does nothing, because the `input_auth_events` are (now) exactly the event's auth_events, which means that `missing_auth` is empty. * we don't care about soft-fail, kicking guest users or `send_on_behalf_of` for outliers ... so the only thing that matters is the auth itself, so let's just do that. * `_auth_and_persist_fetched_events_inner`: de-async `prep` `prep` no longer calls any `async` methods, so let's make it synchronous. * Simplify `_check_event_auth` We no longer need to support outliers here, which makes things rather simpler. * changelog * lint --- changelog.d/10896.misc | 2 +- changelog.d/10926.misc | 1 + synapse/handlers/federation_event.py | 93 ++++++++++++++---------------------- tests/test_federation.py | 1 - 4 files changed, 38 insertions(+), 59 deletions(-) create mode 100644 changelog.d/10926.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10896.misc b/changelog.d/10896.misc index 41de995842..9a765435db 100644 --- a/changelog.d/10896.misc +++ b/changelog.d/10896.misc @@ -1 +1 @@ - Clean up some of the federation event authentication code for clarity. +Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10926.misc b/changelog.d/10926.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10926.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 01fd841122..2c4644b4a3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -68,11 +68,7 @@ from synapse.types import ( UserID, get_domain_from_id, ) -from synapse.util.async_helpers import ( - Linearizer, - concurrently_execute, - yieldable_gather_results, -) +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -1189,7 +1185,10 @@ class FederationEventHandler: allow_rejected=True, ) - async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: + room_version = await self._store.get_room_version_id(room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: with nested_logging_context(suffix=event.event_id): auth = {} for auth_event_id in event.auth_event_ids(): @@ -1207,17 +1206,15 @@ class FederationEventHandler: auth[(ae.type, ae.state_key)] = ae context = EventContext.for_outlier() - context = await self._check_event_auth( - origin, - event, - context, - claimed_auth_event_map=auth, - ) + try: + event_auth.check(room_version_obj, event, auth_events=auth) + except AuthError as e: + logger.warning("Rejecting %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + return event, context - events_to_persist = ( - x for x in await yieldable_gather_results(prep, fetched_events) if x - ) + events_to_persist = (x for x in (prep(event) for event in fetched_events) if x) await self.persist_events_and_notify(room_id, tuple(events_to_persist)) async def _check_event_auth( @@ -1226,7 +1223,6 @@ class FederationEventHandler: event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: """ @@ -1242,42 +1238,36 @@ class FederationEventHandler: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly including events that were rejected, or are in the wrong room. - - Only populated when populating outliers. - backfilled: True if the event was backfilled. Returns: The updated context object. """ - # claimed_auth_event_map should be given iff the event is an outlier - assert bool(claimed_auth_event_map) == event.internal_metadata.outlier + # This method should only be used for non-outliers + assert not event.internal_metadata.outlier room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if claimed_auth_event_map: - # if we have a copy of the auth events from the event, use that as the - # basis for auth. - auth_events = claimed_auth_event_map - else: - # otherwise, we calculate what the auth events *should* be, and use that - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self._store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + # calculate what the auth events *should* be, to use as a basis for auth. + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self._store.get_events(auth_events_ids) + calculated_auth_event_map = { + (e.type, e.state_key): e for e in auth_events_x.values() + } try: ( context, auth_events_for_auth, ) = await self._update_auth_events_and_context_for_auth( - origin, event, context, auth_events + origin, + event, + context, + calculated_auth_event_map=calculated_auth_event_map, ) except Exception: # We don't really mind if the above fails, so lets not fail @@ -1289,7 +1279,7 @@ class FederationEventHandler: "Ignoring failure and continuing processing of event.", event.event_id, ) - auth_events_for_auth = auth_events + auth_events_for_auth = calculated_auth_event_map try: event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) @@ -1425,7 +1415,7 @@ class FederationEventHandler: origin: str, event: EventBase, context: EventContext, - input_auth_events: StateMap[EventBase], + calculated_auth_event_map: StateMap[EventBase], ) -> Tuple[EventContext, StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. @@ -1443,19 +1433,17 @@ class FederationEventHandler: event: context: - input_auth_events: - Map from (event_type, state_key) to event - - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + calculated_auth_event_map: + Our calculated auth_events based on the state of the room + at the event's position in the DAG. Returns: updated context, updated auth event map """ - # take a copy of input_auth_events before we modify it. - auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + assert not event.internal_metadata.outlier + + # take a copy of calculated_auth_event_map before we modify it. + auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map) event_auth_events = set(event.auth_event_ids()) @@ -1496,15 +1484,6 @@ class FederationEventHandler: } ) - if event.internal_metadata.is_outlier(): - # XXX: given that, for an outlier, we'll be working with the - # event's *claimed* auth events rather than those we calculated: - # (a) is there any point in this test, since different_auth below will - # obviously be empty - # (b) alternatively, why don't we do it earlier? - logger.info("Skipping auth_event fetch for outlier") - return context, auth_events - different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) diff --git a/tests/test_federation.py b/tests/test_federation.py index c51e018da1..24fc77d7a7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase): event, context, state=None, - claimed_auth_event_map=None, backfilled=False, ): return context -- cgit 1.5.1 From 9fd057b8c5a8c5748e7d8137d1485c38abd9602f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 28 Sep 2021 21:23:16 -0500 Subject: Ensure `(room_id, next_batch_id)` is unique to avoid cross-talk/conflicts between batches (MSC2716) (#10877) Part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Part of https://github.com/matrix-org/synapse/issues/10737 --- changelog.d/10877.feature | 1 + synapse/handlers/message.py | 34 ++++++++++++++++++++++++++++ synapse/rest/client/room_batch.py | 6 +++-- synapse/storage/databases/main/room_batch.py | 6 +++-- 4 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10877.feature (limited to 'synapse/handlers') diff --git a/changelog.d/10877.feature b/changelog.d/10877.feature new file mode 100644 index 0000000000..06a246c108 --- /dev/null +++ b/changelog.d/10877.feature @@ -0,0 +1 @@ +Ensure `(room_id, next_batch_id)` is unique across [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) insertion events in rooms to avoid cross-talk/conflicts between batches. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c66aefe2c4..07aadf3f3c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,6 +16,7 @@ # limitations under the License. import logging import random +from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple from canonicaljson import encode_canonical_json @@ -1461,6 +1462,39 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") + if event.type == EventTypes.MSC2716_INSERTION: + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + create_event = await self.store.get_create_event_for_room(event.room_id) + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + # Only check an insertion event if the room version + # supports it or the event is from the room creator. + if room_version_obj.msc2716_historical or ( + self.config.experimental.msc2716_enabled + and event.sender == room_creator + ): + next_batch_id = event.content.get( + EventContentFields.MSC2716_NEXT_BATCH_ID + ) + conflicting_insertion_event_id = ( + await self.store.get_insertion_event_by_batch_id( + event.room_id, next_batch_id + ) + ) + if conflicting_insertion_event_id is not None: + # The current insertion event that we're processing is invalid + # because an insertion event already exists in the room with the + # same next_batch_id. We can't allow multiple because the batch + # pointing will get weird, e.g. we can't determine which insertion + # event the batch event is pointing to. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Another insertion event already exists with the same next_batch_id", + errcode=Codes.INVALID_PARAM, + ) + # Mark any `m.historical` messages as backfilled so they don't appear # in `/sync` and have the proper decrementing `stream_ordering` as we import backfilled = False diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index bf14ec384e..1dffcc3147 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -306,11 +306,13 @@ class RoomBatchSendEventRestServlet(RestServlet): # Verify the batch_id_from_query corresponds to an actual insertion event # and have the batch connected. corresponding_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id(batch_id_from_query) + await self.store.get_insertion_event_by_batch_id( + room_id, batch_id_from_query + ) ) if corresponding_insertion_event_id is None: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No insertion event corresponds to the given ?batch_id", errcode=Codes.INVALID_PARAM, ) diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index a383388757..300a563c9e 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -18,7 +18,9 @@ from synapse.storage._base import SQLBaseStore class RoomBatchStore(SQLBaseStore): - async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]: + async def get_insertion_event_by_batch_id( + self, room_id: str, batch_id: str + ) -> Optional[str]: """Retrieve a insertion event ID. Args: @@ -30,7 +32,7 @@ class RoomBatchStore(SQLBaseStore): """ return await self.db_pool.simple_select_one_onecol( table="insertion_events", - keyvalues={"next_batch_id": batch_id}, + keyvalues={"room_id": room_id, "next_batch_id": batch_id}, retcol="event_id", allow_none=True, ) -- cgit 1.5.1 From 5279b9161b323cccdb74dcdf1a68fa7e19f091d4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 29 Sep 2021 10:57:10 +0100 Subject: Use `RoomVersion` objects (#10934) Various refactors to use `RoomVersion` objects instead of room version identifiers. --- changelog.d/10934.misc | 1 + synapse/events/builder.py | 20 ------------------ synapse/handlers/federation.py | 46 ++++++++++++++++++++++++------------------ synapse/handlers/message.py | 27 +++++++++++++++++++------ synapse/handlers/room.py | 4 ++-- 5 files changed, 50 insertions(+), 48 deletions(-) create mode 100644 changelog.d/10934.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10934.misc b/changelog.d/10934.misc new file mode 100644 index 0000000000..56c640ec9e --- /dev/null +++ b/changelog.d/10934.misc @@ -0,0 +1 @@ +Refactor various parts of the codebase to use `RoomVersion` objects instead of room version identifier strings. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 87e2bb123b..50f2a4c1f4 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -18,10 +18,8 @@ import attr from nacl.signing import SigningKey from synapse.api.constants import MAX_DEPTH -from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, - KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersion, ) @@ -197,24 +195,6 @@ class EventBuilderFactory: self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() - def new(self, room_version: str, key_values: dict) -> EventBuilder: - """Generate an event builder appropriate for the given room version - - Deprecated: use for_room_version with a RoomVersion object instead - - Args: - room_version: Version of the room that we're creating an event builder for - key_values: Fields used as the basis of the new event - - Returns: - EventBuilder - """ - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - # this can happen if support is withdrawn for a room version - raise UnsupportedRoomVersionError() - return self.for_room_version(v, key_values) - def for_room_version( self, room_version: RoomVersion, key_values: dict ) -> EventBuilder: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b17ef2a9a1..16c435ee86 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -718,8 +718,8 @@ class FederationHandler(BaseHandler): state_ids, ) - builder = self.event_builder_factory.new( - room_version.identifier, + builder = self.event_builder_factory.for_room_version( + room_version, { "type": EventTypes.Member, "content": event_content, @@ -897,9 +897,9 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - room_version = await self.store.get_room_version_id(room_id) - builder = self.event_builder_factory.new( - room_version, + room_version_obj = await self.store.get_room_version(room_id) + builder = self.event_builder_factory.for_room_version( + room_version_obj, { "type": EventTypes.Member, "content": {"membership": Membership.LEAVE}, @@ -917,7 +917,7 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` await self._event_auth_handler.check_from_context( - room_version, event, context, do_sig_check=False + room_version_obj.identifier, event, context, do_sig_check=False ) except AuthError as e: logger.warning("Failed to create new leave %r because %s", event, e) @@ -949,10 +949,10 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - room_version = await self.store.get_room_version_id(room_id) + room_version_obj = await self.store.get_room_version(room_id) - builder = self.event_builder_factory.new( - room_version, + builder = self.event_builder_factory.for_room_version( + room_version_obj, { "type": EventTypes.Member, "content": {"membership": Membership.KNOCK}, @@ -979,7 +979,7 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` await self._event_auth_handler.check_from_context( - room_version, event, context, do_sig_check=False + room_version_obj.identifier, event, context, do_sig_check=False ) except AuthError as e: logger.warning("Failed to create new knock %r because %s", event, e) @@ -1245,8 +1245,10 @@ class FederationHandler(BaseHandler): } if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname): - room_version = await self.store.get_room_version_id(room_id) - builder = self.event_builder_factory.new(room_version, event_dict) + room_version_obj = await self.store.get_room_version(room_id) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) EventValidator().validate_builder(builder) event, context = await self.event_creation_handler.create_new_client_event( @@ -1254,7 +1256,7 @@ class FederationHandler(BaseHandler): ) event, context = await self.add_display_name_to_third_party_invite( - room_version, event_dict, event, context + room_version_obj, event_dict, event, context ) EventValidator().validate_new(event, self.config) @@ -1265,7 +1267,7 @@ class FederationHandler(BaseHandler): try: await self._event_auth_handler.check_from_context( - room_version, event, context + room_version_obj.identifier, event, context ) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) @@ -1299,22 +1301,24 @@ class FederationHandler(BaseHandler): """ assert_params_in_dict(event_dict, ["room_id"]) - room_version = await self.store.get_room_version_id(event_dict["room_id"]) + room_version_obj = await self.store.get_room_version(event_dict["room_id"]) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. - builder = self.event_builder_factory.new(room_version, event_dict) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) event, context = await self.add_display_name_to_third_party_invite( - room_version, event_dict, event, context + room_version_obj, event_dict, event, context ) try: await self._event_auth_handler.check_from_context( - room_version, event, context + room_version_obj.identifier, event, context ) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) @@ -1331,7 +1335,7 @@ class FederationHandler(BaseHandler): async def add_display_name_to_third_party_invite( self, - room_version: str, + room_version_obj: RoomVersion, event_dict: JsonDict, event: EventBase, context: EventContext, @@ -1363,7 +1367,9 @@ class FederationHandler(BaseHandler): # auth checks. If we need the invite and don't have it then the # auth check code will explode appropriately. - builder = self.event_builder_factory.new(room_version, event_dict) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) EventValidator().validate_builder(builder) event, context = await self.event_creation_handler.create_new_client_event( builder=builder diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 07aadf3f3c..39c18ecf99 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -40,6 +40,7 @@ from synapse.api.errors import ( NotFoundError, ShadowBanError, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder @@ -550,16 +551,22 @@ class EventCreationHandler: await self.auth.check_auth_blocking(requester=requester) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": - room_version = event_dict["content"]["room_version"] + room_version_id = event_dict["content"]["room_version"] + room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version_obj: + # this can happen if support is withdrawn for a room version + raise UnsupportedRoomVersionError(room_version_id) else: try: - room_version = await self.store.get_room_version_id( + room_version_obj = await self.store.get_room_version( event_dict["room_id"] ) except NotFoundError: raise AuthError(403, "Unknown room") - builder = self.event_builder_factory.new(room_version, event_dict) + builder = self.event_builder_factory.for_room_version( + room_version_obj, event_dict + ) self.validator.validate_builder(builder) @@ -1070,9 +1077,17 @@ class EventCreationHandler: EventTypes.Create, "", ): - room_version = event.content.get("room_version", RoomVersions.V1.identifier) + room_version_id = event.content.get( + "room_version", RoomVersions.V1.identifier + ) + room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version_obj: + raise UnsupportedRoomVersionError( + "Attempt to create a room with unsupported room version %s" + % (room_version_id,) + ) else: - room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = await self.store.get_room_version(event.room_id) if event.internal_metadata.is_out_of_band_membership(): # the only sort of out-of-band-membership events we expect to see here are @@ -1082,7 +1097,7 @@ class EventCreationHandler: else: try: await self._event_auth_handler.check_from_context( - room_version, event, context + room_version_obj.identifier, event, context ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 8fede5e935..dc4fab2223 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -237,9 +237,9 @@ class RoomCreationHandler(BaseHandler): }, }, ) - old_room_version = await self.store.get_room_version_id(old_room_id) + old_room_version = await self.store.get_room_version(old_room_id) await self._event_auth_handler.check_from_context( - old_room_version, tombstone_event, tombstone_context + old_room_version.identifier, tombstone_event, tombstone_context ) await self.clone_existing_room( -- cgit 1.5.1 From 94b620a5edd6b5bc55c8aad6e00a11cc6bf210fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Sep 2021 06:44:15 -0400 Subject: Use direct references for configuration variables (part 6). (#10916) --- changelog.d/10916.misc | 1 + synapse/app/_base.py | 8 ++++---- synapse/app/admin_cmd.py | 4 ++-- synapse/app/generic_worker.py | 2 +- synapse/app/homeserver.py | 14 +++++++------- synapse/app/phone_stats_home.py | 8 ++++---- synapse/config/_base.py | 2 +- synapse/config/server.py | 4 +--- synapse/events/presence_router.py | 6 +++--- synapse/events/utils.py | 2 +- synapse/federation/transport/server/__init__.py | 2 +- synapse/handlers/directory.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/message.py | 14 ++++++++------ synapse/handlers/pagination.py | 14 ++++++++++---- synapse/handlers/profile.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 14 +++++++------- synapse/handlers/search.py | 2 +- synapse/handlers/user_directory.py | 2 +- synapse/http/matrixfederationclient.py | 10 +++++----- synapse/replication/tcp/resource.py | 2 +- synapse/rest/client/account.py | 10 +++++----- synapse/rest/client/capabilities.py | 4 ++-- synapse/rest/client/filter.py | 2 +- synapse/rest/client/profile.py | 6 +++--- synapse/rest/client/register.py | 6 +++--- synapse/rest/client/room.py | 2 +- synapse/rest/client/shared_rooms.py | 2 +- synapse/rest/client/sync.py | 2 +- synapse/server_notices/resource_limits_server_notices.py | 8 ++++---- synapse/storage/databases/main/censor_events.py | 8 +++++--- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/events.py | 2 +- synapse/storage/databases/main/monthly_active_users.py | 12 ++++++------ synapse/storage/databases/main/registration.py | 2 +- synapse/storage/databases/main/room.py | 8 ++++---- synapse/storage/databases/main/search.py | 4 ++-- synapse/storage/prepare_database.py | 2 +- tests/api/test_auth.py | 14 +++++++------- tests/federation/test_federation_server.py | 2 +- tests/handlers/test_register.py | 14 +++++++------- tests/http/test_fedclient.py | 2 +- tests/rest/admin/test_user.py | 6 +++--- tests/rest/client/test_account.py | 2 +- tests/rest/client/test_capabilities.py | 2 +- tests/rest/client/test_presence.py | 2 +- tests/rest/client/test_register.py | 4 ++-- .../server_notices/test_resource_limits_server_notices.py | 2 +- tests/storage/test_monthly_active_users.py | 14 +++++++------- tests/test_mau.py | 2 +- tests/unittest.py | 2 +- 54 files changed, 141 insertions(+), 132 deletions(-) create mode 100644 changelog.d/10916.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10916.misc b/changelog.d/10916.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10916.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 548f6dcde9..749bc1deb9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -86,11 +86,11 @@ def start_worker_reactor(appname, config, run_command=reactor.run): start_reactor( appname, - soft_file_limit=config.soft_file_limit, - gc_thresholds=config.gc_thresholds, + soft_file_limit=config.server.soft_file_limit, + gc_thresholds=config.server.gc_thresholds, pid_file=config.worker.worker_pid_file, daemonize=config.worker.worker_daemonize, - print_pidfile=config.print_pidfile, + print_pidfile=config.server.print_pidfile, logger=logger, run_command=run_command, ) @@ -298,7 +298,7 @@ def refresh_certificate(hs): Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. """ - if not hs.config.has_tls_listener(): + if not hs.config.server.has_tls_listener(): return hs.config.read_certificate_from_disk() diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index f2c5b75247..556bcc124e 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -195,14 +195,14 @@ def start(config_options): config.logging.no_redirect_stdio = True # Explicitly disable background processes - config.update_user_directory = False + config.server.update_user_directory = False config.worker.run_background_tasks = False config.start_pushers = False config.pusher_shard_config.instances = [] config.send_federation = False config.federation_shard_config.instances = [] - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts ss = AdminCmdServer( config.server.server_name, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3036e1b4a0..7489f31d9a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -462,7 +462,7 @@ def start(config_options): # For other worker types we force this to off. config.server.update_user_directory = False - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage if config.server.gc_seconds: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 205831dcda..2b2d4bbf83 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -248,7 +248,7 @@ class SynapseHomeServer(HomeServer): resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "webclient": - webclient_loc = self.config.web_client_location + webclient_loc = self.config.server.web_client_location if webclient_loc is None: logger.warning( @@ -343,7 +343,7 @@ def setup(config_options): # generating config files and shouldn't try to continue. sys.exit(0) - events.USE_FROZEN_DICTS = config.use_frozen_dicts + events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage if config.server.gc_seconds: @@ -439,11 +439,11 @@ def run(hs): _base.start_reactor( "synapse-homeserver", - soft_file_limit=hs.config.soft_file_limit, - gc_thresholds=hs.config.gc_thresholds, - pid_file=hs.config.pid_file, - daemonize=hs.config.daemonize, - print_pidfile=hs.config.print_pidfile, + soft_file_limit=hs.config.server.soft_file_limit, + gc_thresholds=hs.config.server.gc_thresholds, + pid_file=hs.config.server.pid_file, + daemonize=hs.config.server.daemonize, + print_pidfile=hs.config.server.print_pidfile, logger=logger, ) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 49e7a45e5c..fcd01e833c 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -74,7 +74,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): store = hs.get_datastore() stats["homeserver"] = hs.config.server.server_name - stats["server_context"] = hs.config.server_context + stats["server_context"] = hs.config.server.server_context stats["timestamp"] = now stats["uptime_seconds"] = uptime version = sys.version_info @@ -171,7 +171,7 @@ def start_phone_stats_home(hs): current_mau_count_by_service = {} reserved_users = () store = hs.get_datastore() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( await store.get_monthly_active_count_by_service() @@ -183,9 +183,9 @@ def start_phone_stats_home(hs): current_mau_by_service_gauge.labels(app_service).set(float(count)) registered_reserved_users_mau_gauge.set(float(len(reserved_users))) - max_mau_gauge.set(float(hs.config.max_mau_value)) + max_mau_gauge.set(float(hs.config.server.max_mau_value)) - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: generate_monthly_active_users() clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings diff --git a/synapse/config/_base.py b/synapse/config/_base.py index d974a1a2a8..26152b0924 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -327,7 +327,7 @@ class RootConfig: """ Redirect lookups on this object either to config objects, or values on config objects, so that `config.tls.blah` works, as well as legacy uses - of things like `config.server_name`. It will first look up the config + of things like `config.server.server_name`. It will first look up the config section name, and then values on those config classes. """ if item in self._configs.keys(): diff --git a/synapse/config/server.py b/synapse/config/server.py index 041412d7ad..818b806357 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1,6 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index eb4556cdc1..68b8b19024 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -45,11 +45,11 @@ def load_legacy_presence_router(hs: "HomeServer"): configuration, and registers the hooks they implement. """ - if hs.config.presence_router_module_class is None: + if hs.config.server.presence_router_module_class is None: return - module = hs.config.presence_router_module_class - config = hs.config.presence_router_config + module = hs.config.server.presence_router_module_class + config = hs.config.server.presence_router_config api = hs.get_module_api() presence_router = module(config=config, module_api=api) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f86113a448..a13fb0148f 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -372,7 +372,7 @@ class EventClientSerializer: def __init__(self, hs): self.store = hs.get_datastore() self.experimental_msc1849_support_enabled = ( - hs.config.experimental_msc1849_support_enabled + hs.config.server.experimental_msc1849_support_enabled ) async def serialize_event( diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 95176ba6f9..c32539bf5a 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -117,7 +117,7 @@ class PublicRoomList(BaseFederationServlet): ): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_room_list_handler() - self.allow_access = hs.config.allow_public_rooms_over_federation + self.allow_access = hs.config.server.allow_public_rooms_over_federation async def on_GET( self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 5cfba3c817..9078781d5a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -49,7 +49,7 @@ class DirectoryHandler(BaseHandler): self.store = hs.get_datastore() self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search - self.require_membership = hs.config.require_membership_for_aliases + self.require_membership = hs.config.server.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() self.federation = hs.get_federation_client() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 16c435ee86..3b0b895b07 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -762,7 +762,7 @@ class FederationHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - if self.hs.config.block_non_admin_invites: + if self.hs.config.server.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") if not await self.spam_checker.user_may_invite( diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index fe8a995892..a0640fcac0 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -57,7 +57,7 @@ class IdentityHandler(BaseHandler): self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. self.blacklisting_http_client = SimpleHttpClient( - hs, ip_blacklist=hs.config.federation_ip_range_blacklist + hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist ) self.federation_http_client = hs.get_federation_http_client() self.hs = hs diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 39c18ecf99..3b8cc50ec0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -81,7 +81,7 @@ class MessageHandler: self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() - self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages # The scheduled call to self._expire_event. None if no call is currently # scheduled. @@ -415,7 +415,9 @@ class EventCreationHandler: self.server_name = hs.hostname self.notifier = hs.get_notifier() self.config = hs.config - self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self.require_membership_for_aliases = ( + hs.config.server.require_membership_for_aliases + ) self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() @@ -425,7 +427,7 @@ class EventCreationHandler: Membership.JOIN, Membership.KNOCK, } - if self.hs.config.include_profile_data_on_invite: + if self.hs.config.server.include_profile_data_on_invite: self.membership_types_to_include_profile_data_in.add(Membership.INVITE) self.send_event = ReplicationSendEventRestServlet.make_client(hs) @@ -461,11 +463,11 @@ class EventCreationHandler: # self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {} # The number of forward extremeities before a dummy event is sent. - self._dummy_events_threshold = hs.config.dummy_events_threshold + self._dummy_events_threshold = hs.config.server.dummy_events_threshold if ( self.config.worker.run_background_tasks - and self.config.cleanup_extremities_with_dummy_events + and self.config.server.cleanup_extremities_with_dummy_events ): self.clock.looping_call( lambda: run_as_background_process( @@ -477,7 +479,7 @@ class EventCreationHandler: self._message_handler = hs.get_message_handler() - self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages self._external_cache = hs.get_external_cache() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a5301ece6f..176e4dfdd4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -85,12 +85,18 @@ class PaginationHandler: self._purges_by_id: Dict[str, PurgeStatus] = {} self._event_serializer = hs.get_event_client_serializer() - self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + self._retention_default_max_lifetime = ( + hs.config.server.retention_default_max_lifetime + ) - self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min - self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max + self._retention_allowed_lifetime_min = ( + hs.config.server.retention_allowed_lifetime_min + ) + self._retention_allowed_lifetime_max = ( + hs.config.server.retention_allowed_lifetime_max + ) - if hs.config.worker.run_background_tasks and hs.config.retention_enabled: + if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled: # Run the purge jobs described in the configuration file. for job in hs.config.server.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b23a1541bc..425c0d4973 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -397,7 +397,7 @@ class ProfileHandler(BaseHandler): # when building a membership event. In this case, we must allow the # lookup. if ( - not self.hs.config.limit_profile_requests_to_users_who_share_rooms + not self.hs.config.server.limit_profile_requests_to_users_who_share_rooms or not requester ): return diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4f99f137a2..4a7ccb882e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -854,7 +854,7 @@ class RegistrationHandler(BaseHandler): # Necessary due to auth checks prior to the threepid being # written to the db if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid + self.hs.config.server.mau_limits_reserved_threepids, threepid ): await self.store.upsert_monthly_active_user(user_id) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index dc4fab2223..bf8a85f563 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -666,7 +666,7 @@ class RoomCreationHandler(BaseHandler): await self.ratelimit(requester) room_version_id = config.get( - "room_version", self.config.default_room_version.identifier + "room_version", self.config.server.default_room_version.identifier ) if not isinstance(room_version_id, str): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1a56c82fbd..02103f6c9a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -90,7 +90,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid self._enable_lookup = hs.config.enable_3pid_lookup - self.allow_per_room_profiles = self.config.allow_per_room_profiles + self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( store=self.store, @@ -617,7 +617,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: logger.info( "Blocking invite: user is not admin and non-admin " "invites disabled" @@ -1222,7 +1222,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): Raises: ShadowBanError if the requester has been shadow-banned. """ - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( @@ -1420,7 +1420,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Returns: bool of whether the complexity is too great, or None if unable to be fetched """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.federation_handler.get_room_complexity( remote_room_hosts, room_id ) @@ -1436,7 +1436,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Args: room_id: The room ID to check for complexity. """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity @@ -1472,7 +1472,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if too_complex is True: raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) @@ -1507,7 +1507,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 8226d6f5a1..6d3333ee00 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -105,7 +105,7 @@ class SearchHandler(BaseHandler): dict to be returned to the client with results of search """ - if not self.hs.config.enable_search: + if not self.hs.config.server.enable_search: raise SynapseError(400, "Search is disabled on this homeserver") batch_group = None diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index b91e7cb501..f4430ce3c9 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -60,7 +60,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.update_user_directory = hs.config.update_user_directory + self.update_user_directory = hs.config.server.update_user_directory self.search_all_users = hs.config.userdirectory.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index cdc36b8d25..4f59224686 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -327,23 +327,23 @@ class MatrixFederationHttpClient: self.reactor = hs.get_reactor() user_agent = hs.version_string - if hs.config.user_agent_suffix: - user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) + if hs.config.server.user_agent_suffix: + user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix) user_agent = user_agent.encode("ascii") federation_agent = MatrixFederationAgent( self.reactor, tls_client_options_factory, user_agent, - hs.config.federation_ip_range_whitelist, - hs.config.federation_ip_range_blacklist, + hs.config.server.federation_ip_range_whitelist, + hs.config.server.federation_ip_range_blacklist, ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( federation_agent, - ip_blacklist=hs.config.federation_ip_range_blacklist, + ip_blacklist=hs.config.server.federation_ip_range_blacklist, ) self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 030852cb5b..80f9b23bfd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -71,7 +71,7 @@ class ReplicationStreamer: self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() - self._replication_torture_level = hs.config.replication_torture_level + self._replication_torture_level = hs.config.server.replication_torture_level self.notifier.add_replication_callback(self.on_notifier_poke) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index bacb828330..fff133ef10 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -119,7 +119,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) if existing_user_id is None: - if self.config.request_token_inhibit_3pid_errors: + if self.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -403,7 +403,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: - if self.config.request_token_inhibit_3pid_errors: + if self.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -486,7 +486,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -857,8 +857,8 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: # If the domain whitelist is set, the domain must be in it if ( valid - and hs.config.next_link_domain_whitelist is not None - and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + and hs.config.server.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.server.next_link_domain_whitelist ): valid = False diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 65b3b5ce2c..d6b6256413 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -44,10 +44,10 @@ class CapabilitiesRestServlet(RestServlet): await self.auth.get_user_by_req(request, allow_guest=True) change_password = self.auth_handler.can_change_password() - response = { + response: JsonDict = { "capabilities": { "m.room_versions": { - "default": self.config.default_room_version.identifier, + "default": self.config.server.default_room_version.identifier, "available": { v.identifier: v.disposition for v in KNOWN_ROOM_VERSIONS.values() diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 6ed60c7418..cc1c2f9731 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -90,7 +90,7 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) + set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit) filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index d0f20de569..c684636c0a 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -41,7 +41,7 @@ class ProfileDisplaynameRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user @@ -94,7 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user @@ -146,7 +146,7 @@ class ProfileRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 48b0062cf4..a6eb6f6410 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -129,7 +129,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -209,7 +209,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -682,7 +682,7 @@ class RegisterRestServlet(RestServlet): # written to the db if threepid: if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid + self.hs.config.server.mau_limits_reserved_threepids, threepid ): await self.store.upsert_monthly_active_user(registered_user_id) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index bf46dc60f2..ed95189b6d 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -369,7 +369,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private # federations. - if not self.hs.config.allow_public_rooms_without_auth: + if not self.hs.config.server.allow_public_rooms_without_auth: raise # We allow people to not be authed if they're just looking at our diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 1d90493eb0..09a46737de 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -42,7 +42,7 @@ class UserSharedRoomsServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.user_directory_active = hs.config.update_user_directory + self.user_directory_active = hs.config.server.update_user_directory async def on_GET( self, request: SynapseRequest, user_id: str diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 1259058b9b..913216a7c4 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -155,7 +155,7 @@ class SyncRestServlet(RestServlet): try: filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( - filter_object, self.hs.config.filter_timeline_limit + filter_object, self.hs.config.server.filter_timeline_limit ) except Exception: raise SynapseError(400, "Invalid filter JSON") diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 073b0d754f..8522930b50 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -47,9 +47,9 @@ class ResourceLimitsServerNotices: self._notifier = hs.get_notifier() self._enabled = ( - hs.config.limit_usage_by_mau + hs.config.server.limit_usage_by_mau and self._server_notices_manager.is_enabled() - and not hs.config.hs_disabled + and not hs.config.server.hs_disabled ) async def maybe_send_server_notice_to_user(self, user_id: str) -> None: @@ -98,7 +98,7 @@ class ResourceLimitsServerNotices: try: if ( limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER - and not self._config.mau_limit_alerting + and not self._config.server.mau_limit_alerting ): # We have hit the MAU limit, but MAU alerting is disabled: # reset room if necessary and return @@ -149,7 +149,7 @@ class ResourceLimitsServerNotices: "body": event_body, "msgtype": ServerNoticeMsgType, "server_notice_type": ServerNoticeLimitReached, - "admin_contact": self._config.admin_contact, + "admin_contact": self._config.server.admin_contact, "limit_type": event_limit_type, } event = await self._server_notices_manager.send_notice( diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 6305414e3d..eee07227ef 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -36,7 +36,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase if ( hs.config.worker.run_background_tasks - and self.hs.config.redaction_retention_period is not None + and self.hs.config.server.redaction_retention_period is not None ): hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) @@ -48,7 +48,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase By censor we mean update the event_json table with the redacted event. """ - if self.hs.config.redaction_retention_period is None: + if self.hs.config.server.redaction_retention_period is None: return if not ( @@ -60,7 +60,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # created. return - before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + before_ts = ( + self._clock.time_msec() - self.hs.config.server.redaction_retention_period + ) # We fetch all redactions that: # 1. point to an event we have, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 7e33ae578c..0e1d97aaeb 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -353,7 +353,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.user_ips_max_age = hs.config.user_ips_max_age + self.user_ips_max_age = hs.config.server.user_ips_max_age if hs.config.worker.run_background_tasks and self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index cc4e31ec30..bc7d213fe2 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -104,7 +104,7 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id # Ideally we'd move these ID gens here, unfortunately some other ID diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index b76ee51a9b..a14ac03d4b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -32,8 +32,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._max_mau_value = hs.config.max_mau_value + self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau + self._max_mau_value = hs.config.server.max_mau_value @cached(num_args=0) async def get_monthly_active_count(self) -> int: @@ -96,8 +96,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ users = [] - for tp in self.hs.config.mau_limits_reserved_threepids[ - : self.hs.config.max_mau_value + for tp in self.hs.config.server.mau_limits_reserved_threepids[ + : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( tp["medium"], tp["address"] @@ -212,7 +212,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._mau_stats_only = hs.config.mau_stats_only + self._mau_stats_only = hs.config.server.mau_stats_only # Do not add more reserved users than the total allowable number self.db_pool.new_transaction( @@ -221,7 +221,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): [], [], self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) def _initialise_reserved_users(self, txn, threepids): diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c83089ee63..7279b0924e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return False now = self._clock.time_msec() - trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 118b390e93..d69eaf80ce 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore): # policy. if not ret: return { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, + "min_lifetime": self.config.server.retention_default_min_lifetime, + "max_lifetime": self.config.server.retention_default_max_lifetime, } row = ret[0] @@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore): # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.retention_default_min_lifetime + row["min_lifetime"] = self.config.server.retention_default_min_lifetime if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.retention_default_max_lifetime + row["max_lifetime"] = self.config.server.retention_default_max_lifetime return row diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2a1e99e17a..c85383c975 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -51,7 +51,7 @@ class SearchWorkerStore(SQLBaseStore): txn: entries: entries to be added to the table """ - if not self.hs.config.enable_search: + if not self.hs.config.server.enable_search: return if isinstance(self.database_engine, PostgresEngine): sql = ( @@ -105,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - if not hs.config.enable_search: + if not hs.config.server.enable_search: return self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index f31880b8ec..a63eaddfdc 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -366,7 +366,7 @@ def _upgrade_existing_database( + "new for the server to understand" ) - # some of the deltas assume that config.server_name is set correctly, so now + # some of the deltas assume that server_name is set correctly, so now # is a good time to run the sanity check. if not is_empty and "main" in databases: from synapse.storage.databases.main import check_database_before_upgrade diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cccff7af26..3aa9ba3c43 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -303,7 +303,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -332,7 +332,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -372,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -387,7 +387,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 0b60cc4261..03e1e11f49 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): self.assertEqual( channel.json_body["room_version"], - self.hs.config.default_room_version.identifier, + self.hs.config.server.default_room_version.identifier, ) members = set( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index d3efb67e3e..bd05a2c2d1 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -175,20 +175,20 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(result_token is not None) def test_mau_limits_when_disabled(self): - self.hs.config.limit_usage_by_mau = False + self.hs.config.server.limit_usage_by_mau = False # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) def test_get_or_create_user_mau_not_blocked(self): - self.hs.config.limit_usage_by_mau = True + self.hs.config.server.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value - 1) + return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) def test_get_or_create_user_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True + self.hs.config.server.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -198,7 +198,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -206,7 +206,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) def test_register_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True + self.hs.config.server.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -215,7 +215,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index d9a8b077d3..638babae69 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -226,7 +226,7 @@ class FederationClientTests(HomeserverTestCase): """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist - self.hs.config.federation_ip_range_blacklist = IPSet( + self.hs.config.server.federation_ip_range_blacklist = IPSet( ["127.0.0.0/8", "fe80::/64"] ) self.reactor.lookups["internal"] = "127.0.0.1" diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ee3ae9cce4..a285d5a7fe 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -422,7 +422,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1485,7 +1485,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1522,7 +1522,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 64b0b8458b..2f44547bfb 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -516,7 +516,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 422361b62a..b9e3602552 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -55,7 +55,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( - self.config.default_room_version.identifier, + self.config.server.default_room_version.identifier, capabilities["m.room_versions"]["default"], ) diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 1d152352d1..56fe1a3d01 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -50,7 +50,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. """ - self.hs.config.use_presence = True + self.hs.config.server.use_presence = True body = {"presence": "here", "status_msg": "beep boop"} channel = self.make_request( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 72a5a11b46..af135d57e1 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -50,7 +50,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -74,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 7f25200a5d..36c495954f 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -346,7 +346,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): invites = [] # Register as many users as the MAU limit allows. - for i in range(self.hs.config.max_mau_value): + for i in range(self.hs.config.server.max_mau_value): localpart = "user%d" % i user_id = self.register_user(localpart, "password") tok = self.login(localpart, "password") diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 944dbc34a2..d6b4cdd788 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -51,7 +51,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids # register three users, of which two have reserved 3pids, and a third # which is a support user. @@ -101,9 +101,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # XXX some of this is redundant. poking things into the config shouldn't # work, and in any case it's not obvious what we expect to happen when # we advance the reactor. - self.hs.config.max_mau_value = 0 + self.hs.config.server.max_mau_value = 0 self.reactor.advance(FORTY_DAYS) - self.hs.config.max_mau_value = 5 + self.hs.config.server.max_mau_value = 5 self.get_success(self.store.reap_monthly_active_users()) @@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) self.reactor.advance(FORTY_DAYS) @@ -199,7 +199,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_reap_monthly_active_users_reserved_users(self): """Tests that reaping correctly handles reaping where reserved users are present""" - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids initial_users = len(threepids) reserved_user_number = initial_users - 1 for i in range(initial_users): @@ -234,7 +234,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list @@ -294,7 +294,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, ] - self.hs.config.mau_limits_reserved_threepids = threepids + self.hs.config.server.mau_limits_reserved_threepids = threepids d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/test_mau.py b/tests/test_mau.py index 66111eb367..80ab40e255 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -165,7 +165,7 @@ class TestMauLimit(unittest.HomeserverTestCase): @override_config({"mau_trial_days": 1}) def test_trial_users_cant_come_back(self): - self.hs.config.mau_trial_days = 1 + self.hs.config.server.mau_trial_days = 1 # We should be able to register more than the limit initially token1 = self.create_user("kermit1") diff --git a/tests/unittest.py b/tests/unittest.py index 7a6f5954d0..6d5d87cb78 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -232,7 +232,7 @@ class HomeserverTestCase(TestCase): # Honour the `use_frozen_dicts` config option. We have to do this # manually because this is taken care of in the app `start` code, which # we don't run. Plus we want to reset it on tearDown. - events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts + events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts if self.hs is None: raise Exception("No homeserver returned from make_homeserver.") -- cgit 1.5.1 From 428174f90249ec50f977b5ef5c5cf9f92599ee0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 29 Sep 2021 18:59:15 +0100 Subject: Split `event_auth.check` into two parts (#10940) Broadly, the existing `event_auth.check` function has two parts: * a validation section: checks that the event isn't too big, that it has the rught signatures, etc. This bit is independent of the rest of the state in the room, and so need only be done once for each event. * an auth section: ensures that the event is allowed, given the rest of the state in the room. This gets done multiple times, against various sets of room state, because it forms part of the state res algorithm. Currently, this is implemented with `do_sig_check` and `do_size_check` parameters, but I think that makes everything hard to follow. Instead, we split the function in two and call each part separately where it is needed. --- changelog.d/10940.misc | 1 + synapse/event_auth.py | 153 +++++++++++++++++++++-------------- synapse/handlers/event_auth.py | 15 ++-- synapse/handlers/federation.py | 30 ++++--- synapse/handlers/federation_event.py | 18 +++-- synapse/handlers/message.py | 6 +- synapse/handlers/room.py | 6 +- synapse/state/v1.py | 8 +- synapse/state/v2.py | 4 +- tests/test_event_auth.py | 108 +++++++++---------------- 10 files changed, 177 insertions(+), 172 deletions(-) create mode 100644 changelog.d/10940.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10940.misc b/changelog.d/10940.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10940.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5d7c6fa858..eef354de6e 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -41,42 +41,112 @@ from synapse.types import StateMap, UserID, get_domain_from_id logger = logging.getLogger(__name__) -def check( - room_version_obj: RoomVersion, - event: EventBase, - auth_events: StateMap[EventBase], - do_sig_check: bool = True, - do_size_check: bool = True, +def validate_event_for_room_version( + room_version_obj: RoomVersion, event: EventBase ) -> None: - """Checks if this event is correctly authed. + """Ensure that the event complies with the limits, and has the right signatures + + NB: does not *validate* the signatures - it assumes that any signatures present + have already been checked. + + NB: it does not check that the event satisfies the auth rules (that is done in + check_auth_rules_for_event) - these tests are independent of the rest of the state + in the room. + + NB: This is used to check events that have been received over federation. As such, + it can only enforce the checks specified in the relevant room version, to avoid + a split-brain situation where some servers accept such events, and others reject + them. + + TODO: consider moving this into EventValidator Args: - room_version_obj: the version of the room - event: the event being checked. - auth_events: the existing room state. - do_sig_check: True if it should be verified that the sending server - signed the event. - do_size_check: True if the size of the event fields should be verified. + room_version_obj: the version of the room which contains this event + event: the event to be checked Raises: - AuthError if the checks fail - - Returns: - if the auth checks pass. + SynapseError if there is a problem with the event """ - assert isinstance(auth_events, dict) - - if do_size_check: - _check_size_limits(event) + _check_size_limits(event) if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) - room_id = event.room_id + # check that the event has the correct signatures + sender_domain = get_domain_from_id(event.sender) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + # Check the sender's domain has signed the event + if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a different + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down with the + # other dedicated membership checks. + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + if event.format_version in (EventFormatVersions.V1,): + # Only older room versions have event IDs to check. + event_id_domain = get_domain_from_id(event.event_id) + + # Check the origin domain has signed the event + if not event.signatures.get(event_id_domain): + raise AuthError(403, "Event not signed by sending server") + + is_invite_via_allow_rule = ( + room_version_obj.msc3083_join_rules + and event.type == EventTypes.Member + and event.membership == Membership.JOIN + and "join_authorised_via_users_server" in event.content + ) + if is_invite_via_allow_rule: + authoriser_domain = get_domain_from_id( + event.content["join_authorised_via_users_server"] + ) + if not event.signatures.get(authoriser_domain): + raise AuthError(403, "Event not signed by authorising server") + + +def check_auth_rules_for_event( + room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase] +) -> None: + """Check that an event complies with the auth rules + + Checks whether an event passes the auth rules with a given set of state events + + Assumes that we have already checked that the event is the right shape (it has + enough signatures, has a room ID, etc). In other words: + + - it's fine for use in state resolution, when we have already decided whether to + accept the event or not, and are now trying to decide whether it should make it + into the room state + + - when we're doing the initial event auth, it is only suitable in combination with + a bunch of other tests. + + Args: + room_version_obj: the version of the room + event: the event being checked. + auth_events: the room state to check the events against. + + Raises: + AuthError if the checks fail + """ + assert isinstance(auth_events, dict) # We need to ensure that the auth events are actually for the same room, to # stop people from using powers they've been granted in other rooms for # example. + # + # Arguably we don't need to do this when we're just doing state res, as presumably + # the state res algorithm isn't silly enough to give us events from different rooms. + # Still, it's easier to do it anyway. + room_id = event.room_id for auth_event in auth_events.values(): if auth_event.room_id != room_id: raise AuthError( @@ -86,45 +156,6 @@ def check( % (event.event_id, room_id, auth_event.event_id, auth_event.room_id), ) - if do_sig_check: - sender_domain = get_domain_from_id(event.sender) - - is_invite_via_3pid = ( - event.type == EventTypes.Member - and event.membership == Membership.INVITE - and "third_party_invite" in event.content - ) - - # Check the sender's domain has signed the event - if not event.signatures.get(sender_domain): - # We allow invites via 3pid to have a sender from a different - # HS, as the sender must match the sender of the original - # 3pid invite. This is checked further down with the - # other dedicated membership checks. - if not is_invite_via_3pid: - raise AuthError(403, "Event not signed by sender's server") - - if event.format_version in (EventFormatVersions.V1,): - # Only older room versions have event IDs to check. - event_id_domain = get_domain_from_id(event.event_id) - - # Check the origin domain has signed the event - if not event.signatures.get(event_id_domain): - raise AuthError(403, "Event not signed by sending server") - - is_invite_via_allow_rule = ( - room_version_obj.msc3083_join_rules - and event.type == EventTypes.Member - and event.membership == Membership.JOIN - and "join_authorised_via_users_server" in event.content - ) - if is_invite_via_allow_rule: - authoriser_domain = get_domain_from_id( - event.content["join_authorised_via_users_server"] - ) - if not event.signatures.get(authoriser_domain): - raise AuthError(403, "Event not signed by authorising server") - # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules # # 1. If type is m.room.create: diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index cb81fa0986..d089c56286 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -22,7 +22,8 @@ from synapse.api.constants import ( RestrictedJoinRuleTypes, ) from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.api.room_versions import RoomVersion +from synapse.event_auth import check_auth_rules_for_event from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext @@ -45,21 +46,17 @@ class EventAuthHandler: self._store = hs.get_datastore() self._server_name = hs.hostname - async def check_from_context( + async def check_auth_rules_from_context( self, - room_version: str, + room_version_obj: RoomVersion, event: EventBase, context: EventContext, - do_sig_check: bool = True, ) -> None: + """Check an event passes the auth rules at its own auth events""" auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} - - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - event_auth.check( - room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check - ) + check_auth_rules_for_event(room_version_obj, event, auth_events) def compute_auth_events( self, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3b0b895b07..0a10a5c28a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -40,6 +40,10 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature +from synapse.event_auth import ( + check_auth_rules_for_event, + validate_event_for_room_version, +) from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -742,10 +746,9 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self._event_auth_handler.check_from_context( - room_version.identifier, event, context, do_sig_check=False + await self._event_auth_handler.check_auth_rules_from_context( + room_version, event, context ) - return event async def on_invite_request( @@ -916,8 +919,8 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self._event_auth_handler.check_from_context( - room_version_obj.identifier, event, context, do_sig_check=False + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Failed to create new leave %r because %s", event, e) @@ -978,8 +981,8 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self._event_auth_handler.check_from_context( - room_version_obj.identifier, event, context, do_sig_check=False + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Failed to create new knock %r because %s", event, e) @@ -1168,7 +1171,8 @@ class FederationHandler(BaseHandler): auth_for_e[(EventTypes.Create, "")] = create_event try: - event_auth.check(room_version, e, auth_events=auth_for_e) + validate_event_for_room_version(room_version, e) + check_auth_rules_for_event(room_version, e, auth_for_e) except SynapseError as err: # we may get SynapseErrors here as well as AuthErrors. For # instance, there are a couple of (ancient) events in some @@ -1266,8 +1270,9 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - await self._event_auth_handler.check_from_context( - room_version_obj.identifier, event, context + validate_event_for_room_version(room_version_obj, event) + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) @@ -1317,8 +1322,9 @@ class FederationHandler(BaseHandler): ) try: - await self._event_auth_handler.check_from_context( - room_version_obj.identifier, event, context + validate_event_for_room_version(room_version_obj, event) + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 2c4644b4a3..e587b5b3b3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -29,7 +29,6 @@ from typing import ( from prometheus_client import Counter -from synapse import event_auth from synapse.api.constants import ( EventContentFields, EventTypes, @@ -47,7 +46,11 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.event_auth import auth_types_for_event +from synapse.event_auth import ( + auth_types_for_event, + check_auth_rules_for_event, + validate_event_for_room_version, +) from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError @@ -1207,7 +1210,8 @@ class FederationEventHandler: context = EventContext.for_outlier() try: - event_auth.check(room_version_obj, event, auth_events=auth) + validate_event_for_room_version(room_version_obj, event) + check_auth_rules_for_event(room_version_obj, event, auth) except AuthError as e: logger.warning("Rejecting %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -1282,7 +1286,8 @@ class FederationEventHandler: auth_events_for_auth = calculated_auth_event_map try: - event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) + validate_event_for_room_version(room_version_obj, event) + check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -1394,7 +1399,10 @@ class FederationEventHandler: } try: - event_auth.check(room_version_obj, event, auth_events=current_auth_events) + # TODO: skip the call to validate_event_for_room_version? we should already + # have validated the event. + validate_event_for_room_version(room_version_obj, event) + check_auth_rules_for_event(room_version_obj, event, current_auth_events) except AuthError as e: logger.warning( "Soft-failing %r (from %s) because %s", diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3b8cc50ec0..cdac53037c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -44,6 +44,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext @@ -1098,8 +1099,9 @@ class EventCreationHandler: assert event.content["membership"] == Membership.LEAVE else: try: - await self._event_auth_handler.check_from_context( - room_version_obj.identifier, event, context + validate_event_for_room_version(room_version_obj, event) + await self._event_auth_handler.check_auth_rules_from_context( + room_version_obj, event, context ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bf8a85f563..873e08258e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -52,6 +52,7 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.rest.admin._base import assert_user_is_admin @@ -238,8 +239,9 @@ class RoomCreationHandler(BaseHandler): }, ) old_room_version = await self.store.get_room_version(old_room_id) - await self._event_auth_handler.check_from_context( - old_room_version.identifier, tombstone_event, tombstone_context + validate_event_for_room_version(old_room_version, tombstone_event) + await self._event_auth_handler.check_auth_rules_from_context( + old_room_version, tombstone_event, tombstone_context ) await self.clone_existing_room( diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 92336d7cc8..017e6fd92d 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -329,12 +329,10 @@ def _resolve_auth_events( auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: # The signatures have already been checked at this point - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, event, auth_events, - do_sig_check=False, - do_size_check=False, ) prev_event = event except AuthError: @@ -349,12 +347,10 @@ def _resolve_normal_events( for event in _ordered_events(events): try: # The signatures have already been checked at this point - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, event, auth_events, - do_sig_check=False, - do_size_check=False, ) return event except AuthError: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 7b1e8361de..586b0e12fe 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -546,12 +546,10 @@ async def _iterative_auth_checks( auth_events[key] = event_map[ev_id] try: - event_auth.check( + event_auth.check_auth_rules_for_event( room_version, event, auth_events, - do_sig_check=False, - do_size_check=False, ) resolved_state[(event.type, event.state_key)] = event_id diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 6ebd01bcbe..e7a7d00883 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -37,21 +37,19 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send state - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _random_state_event(creator), auth_events, - do_sig_check=False, ) # joiner should not be able to send state self.assertRaises( AuthError, - event_auth.check, + event_auth.check_auth_rules_for_event, RoomVersions.V1, _random_state_event(joiner), auth_events, - do_sig_check=False, ) def test_state_default_level(self): @@ -76,19 +74,17 @@ class EventAuthTestCase(unittest.TestCase): # pleb should not be able to send state self.assertRaises( AuthError, - event_auth.check, + event_auth.check_auth_rules_for_event, RoomVersions.V1, _random_state_event(pleb), auth_events, - do_sig_check=False, ), # king should be able to send state - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _random_state_event(king), auth_events, - do_sig_check=False, ) def test_alias_event(self): @@ -101,37 +97,33 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send aliases - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator), auth_events, - do_sig_check=False, ) # Reject an event with no state key. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator, state_key=""), auth_events, - do_sig_check=False, ) # If the domain of the sender does not match the state key, reject. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(creator, state_key="test.com"), auth_events, - do_sig_check=False, ) # Note that the member does *not* need to be in the room. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _alias_event(other), auth_events, - do_sig_check=False, ) def test_msc2432_alias_event(self): @@ -144,34 +136,30 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send aliases - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator), auth_events, - do_sig_check=False, ) # No particular checks are done on the state key. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator, state_key=""), auth_events, - do_sig_check=False, ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(creator, state_key="test.com"), auth_events, - do_sig_check=False, ) # Per standard auth rules, the member must be in the room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _alias_event(other), auth_events, - do_sig_check=False, ) def test_msc2209(self): @@ -191,20 +179,18 @@ class EventAuthTestCase(unittest.TestCase): } # pleb should be able to modify the notifications power level. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V1, _power_levels_event(pleb, {"notifications": {"room": 100}}), auth_events, - do_sig_check=False, ) # But an MSC2209 room rejects this change. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _power_levels_event(pleb, {"notifications": {"room": 100}}), auth_events, - do_sig_check=False, ) def test_join_rules_public(self): @@ -221,59 +207,53 @@ class EventAuthTestCase(unittest.TestCase): } # Check join. - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) def test_join_rules_invite(self): @@ -291,60 +271,54 @@ class EventAuthTestCase(unittest.TestCase): # A join without an invite is rejected. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user who left cannot re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) def test_join_rules_msc3083_restricted(self): @@ -369,11 +343,10 @@ class EventAuthTestCase(unittest.TestCase): # Older room versions don't understand this join rule with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), auth_events, - do_sig_check=False, ) # A properly formatted join event should work. @@ -383,11 +356,10 @@ class EventAuthTestCase(unittest.TestCase): "join_authorised_via_users_server": "@creator:example.com" }, ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A join issued by a specific user works (i.e. the power level checks @@ -399,7 +371,7 @@ class EventAuthTestCase(unittest.TestCase): pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event( "@inviter:foo.test" ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event( pleb, @@ -408,16 +380,14 @@ class EventAuthTestCase(unittest.TestCase): }, ), pl_auth_events, - do_sig_check=False, ) # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) # An join authorised by a user who is not in the room is rejected. @@ -426,7 +396,7 @@ class EventAuthTestCase(unittest.TestCase): creator, {"invite": 100, "users": {"@other:example.com": 150}} ) with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event( pleb, @@ -435,13 +405,12 @@ class EventAuthTestCase(unittest.TestCase): }, ), auth_events, - do_sig_check=False, ) # A user cannot be force-joined to a room. (This uses an event which # *would* be valid, but is sent be a different user.) with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _member_event( pleb, @@ -452,36 +421,32 @@ class EventAuthTestCase(unittest.TestCase): }, ), auth_events, - do_sig_check=False, ) # Banned should be rejected. auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, auth_events, - do_sig_check=False, ) # A user can send a join if they're in the room. (This doesn't need to # be authorised since the user is already joined.) auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) # A user can accept an invite. (This doesn't need to be authorised since @@ -489,11 +454,10 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) - event_auth.check( + event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), auth_events, - do_sig_check=False, ) -- cgit 1.5.1 From 29364145b29e84c5dcab076c4e0d436ebf77e4cd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 30 Sep 2021 12:51:47 +0100 Subject: Pass str to twisted's IReactorTCP (#10895) This follows a correction made in twisted/twisted#1664 and should fix our Twisted Trial CI job. Until that change is in a twisted release, we'll have to ignore the type of the `host` argument. I've raised #10899 to remind us to review the issue in a few months' time. --- changelog.d/10895.misc | 1 + synapse/handlers/send_email.py | 9 +++++++-- synapse/replication/tcp/handler.py | 8 ++++++-- synapse/replication/tcp/redis.py | 8 +++++++- tests/replication/_base.py | 4 ++-- tests/server.py | 8 ++++---- 6 files changed, 27 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10895.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10895.misc b/changelog.d/10895.misc new file mode 100644 index 0000000000..d1c8224980 --- /dev/null +++ b/changelog.d/10895.misc @@ -0,0 +1 @@ +Fix type hints to be compatible with an upcoming change to Twisted. \ No newline at end of file diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 25e6b012b7..1a062a784c 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -105,8 +105,13 @@ async def _sendmail( # set to enable TLS. factory = build_sender_factory(hostname=smtphost if enable_tls else None) - # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong - reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] + reactor.connectTCP( + smtphost, # type: ignore[arg-type] + smtpport, + factory, + timeout=30, + bindAddress=None, + ) await make_deferred_yieldable(d) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1438a82b60..d64d1dbacd 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -315,7 +315,7 @@ class ReplicationCommandHandler: hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host.encode(), + hs.config.redis.redis_host, # type: ignore[arg-type] hs.config.redis.redis_port, self._factory, ) @@ -324,7 +324,11 @@ class ReplicationCommandHandler: self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker.worker_replication_host port = hs.config.worker.worker_replication_port - hs.get_reactor().connectTCP(host.encode(), port, self._factory) + hs.get_reactor().connectTCP( + host, # type: ignore[arg-type] + port, + self._factory, + ) def get_streams(self) -> Dict[str, Stream]: """Get a map from stream name to all streams.""" diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 8c0df627c8..062fe2f33e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -364,6 +364,12 @@ def lazyConnection( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) + reactor.connectTCP( + host, # type: ignore[arg-type] + port, + factory, + timeout=30, + bindAddress=None, + ) return factory.handler diff --git a/tests/replication/_base.py b/tests/replication/_base.py index c7555c26db..cdd6e3d3c1 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -240,7 +240,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - b"localhost", + "localhost", 6379, self.connect_any_redis_attempts, ) @@ -424,7 +424,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): clients = self.reactor.tcpClients while clients: (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, b"localhost") + self.assertEqual(host, "localhost") self.assertEqual(port, 6379) client_protocol = client_factory.buildProtocol(None) diff --git a/tests/server.py b/tests/server.py index 88dfa8058e..64645651ce 100644 --- a/tests/server.py +++ b/tests/server.py @@ -317,7 +317,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def __init__(self): self.threadpool = ThreadPool(self) - self._tcp_callbacks = {} + self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} self._udp = [] self.lookups: Dict[str, str] = {} self._thread_callbacks: Deque[Callable[[], None]] = deque() @@ -355,7 +355,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def getThreadPool(self): return self.threadpool - def add_tcp_client_callback(self, host, port, callback): + def add_tcp_client_callback(self, host: str, port: int, callback: Callable): """Add a callback that will be invoked when we receive a connection attempt to the given IP/port using `connectTCP`. @@ -364,7 +364,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ self._tcp_callbacks[(host, port)] = callback - def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): + def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( @@ -475,7 +475,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): return server -def get_clock(): +def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) return clock, hs_clock -- cgit 1.5.1 From d1bf5f7c9d669fcf60aadc2c6527447adef2c43c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Sep 2021 11:13:59 -0400 Subject: Strip "join_authorised_via_users_server" from join events which do not need it. (#10933) This fixes a "Event not signed by authorising server" error when transition room member from join -> join, e.g. when updating a display name or avatar URL for restricted rooms. --- changelog.d/10933.bugfix | 1 + synapse/api/constants.py | 3 +++ synapse/event_auth.py | 12 +++++++----- synapse/events/utils.py | 2 +- synapse/federation/federation_base.py | 6 +++--- synapse/federation/federation_client.py | 6 +++--- synapse/federation/federation_server.py | 6 +++--- synapse/handlers/federation.py | 9 +++++++-- synapse/handlers/room_member.py | 10 +++++++++- tests/events/test_utils.py | 7 ++++--- tests/test_event_auth.py | 9 +++++---- 11 files changed, 46 insertions(+), 25 deletions(-) create mode 100644 changelog.d/10933.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/10933.bugfix b/changelog.d/10933.bugfix new file mode 100644 index 0000000000..e0694fea22 --- /dev/null +++ b/changelog.d/10933.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.40.0 where changing a user's display name or avatar in a restricted room would cause an authentication error. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 39fd9954d5..a31f037748 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -217,6 +217,9 @@ class EventContentFields: # For "marker" events MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" + # The authorising user for joining a restricted room. + AUTHORISING_USER = "join_authorised_via_users_server" + class RoomTypes: """Understood values of the room_type field of m.room.create events.""" diff --git a/synapse/event_auth.py b/synapse/event_auth.py index eef354de6e..7a1adc2750 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -102,11 +102,11 @@ def validate_event_for_room_version( room_version_obj.msc3083_join_rules and event.type == EventTypes.Member and event.membership == Membership.JOIN - and "join_authorised_via_users_server" in event.content + and EventContentFields.AUTHORISING_USER in event.content ) if is_invite_via_allow_rule: authoriser_domain = get_domain_from_id( - event.content["join_authorised_via_users_server"] + event.content[EventContentFields.AUTHORISING_USER] ) if not event.signatures.get(authoriser_domain): raise AuthError(403, "Event not signed by authorising server") @@ -413,7 +413,9 @@ def _is_membership_change_allowed( # Note that if the caller is in the room or invited, then they do # not need to meet the allow rules. if not caller_in_room and not caller_invited: - authorising_user = event.content.get("join_authorised_via_users_server") + authorising_user = event.content.get( + EventContentFields.AUTHORISING_USER + ) if authorising_user is None: raise AuthError(403, "Join event is missing authorising user.") @@ -868,10 +870,10 @@ def auth_types_for_event( auth_types.add(key) if room_version.msc3083_join_rules and membership == Membership.JOIN: - if "join_authorised_via_users_server" in event.content: + if EventContentFields.AUTHORISING_USER in event.content: key = ( EventTypes.Member, - event.content["join_authorised_via_users_server"], + event.content[EventContentFields.AUTHORISING_USER], ) auth_types.add(key) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a13fb0148f..520edbbf61 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -105,7 +105,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: if event_type == EventTypes.Member: add_fields("membership") if room_version.msc3375_redaction_rules: - add_fields("join_authorised_via_users_server") + add_fields(EventContentFields.AUTHORISING_USER) elif event_type == EventTypes.Create: # MSC2176 rules state that create events cannot be redacted. if room_version.msc2176_redaction_rules: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 024e440ff4..0cd424e12a 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from synapse.api.constants import MAX_DEPTH, EventTypes, Membership +from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.crypto.event_signing import check_event_content_hash @@ -184,10 +184,10 @@ async def _check_sigs_on_pdu( room_version.msc3083_join_rules and pdu.type == EventTypes.Member and pdu.membership == Membership.JOIN - and "join_authorised_via_users_server" in pdu.content + and EventContentFields.AUTHORISING_USER in pdu.content ): authorising_server = get_domain_from_id( - pdu.content["join_authorised_via_users_server"] + pdu.content[EventContentFields.AUTHORISING_USER] ) try: await keyring.verify_event_for_server( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 584836c04a..2ab4dec88f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -37,7 +37,7 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( CodeMessageException, Codes, @@ -875,9 +875,9 @@ class FederationClient(FederationBase): # If the join is being authorised via allow rules, we need to send # the /send_join back to the same server that was originally used # with /make_join. - if "join_authorised_via_users_server" in pdu.content: + if EventContentFields.AUTHORISING_USER in pdu.content: destinations = [ - get_domain_from_id(pdu.content["join_authorised_via_users_server"]) + get_domain_from_id(pdu.content[EventContentFields.AUTHORISING_USER]) ] return await self._try_destination_list( diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 83f11d6b88..d8c0b86f23 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -765,11 +765,11 @@ class FederationServer(FederationBase): if ( room_version.msc3083_join_rules and event.membership == Membership.JOIN - and "join_authorised_via_users_server" in event.content + and EventContentFields.AUTHORISING_USER in event.content ): # We can only authorise our own users. authorising_server = get_domain_from_id( - event.content["join_authorised_via_users_server"] + event.content[EventContentFields.AUTHORISING_USER] ) if authorising_server != self.server_name: raise SynapseError( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0a10a5c28a..043ca4a224 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -27,7 +27,12 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RejectedReason, +) from synapse.api.errors import ( AuthError, CodeMessageException, @@ -716,7 +721,7 @@ class FederationHandler(BaseHandler): if include_auth_user_id: event_content[ - "join_authorised_via_users_server" + EventContentFields.AUTHORISING_USER ] = await self._event_auth_handler.get_user_which_could_invite( room_id, state_ids, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 02103f6c9a..29b3e41cc9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -573,6 +573,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): errcode=Codes.BAD_JSON, ) + # The event content should *not* include the authorising user as + # it won't be properly signed. Strip it out since it might come + # back from a client updating a display name / avatar. + # + # This only applies to restricted rooms, but there should be no reason + # for a client to include it. Unconditionally remove it. + content.pop(EventContentFields.AUTHORISING_USER, None) + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" @@ -939,7 +947,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # be included in the event content in order to efficiently validate # the event. content[ - "join_authorised_via_users_server" + EventContentFields.AUTHORISING_USER ] = await self.event_auth_handler.get_user_which_could_invite( room_id, current_state_ids, diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 5446fda5e7..1dea09e480 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( @@ -352,7 +353,7 @@ class PruneEventTestCase(unittest.TestCase): "event_id": "$test:domain", "content": { "membership": "join", - "join_authorised_via_users_server": "@user:domain", + EventContentFields.AUTHORISING_USER: "@user:domain", "other_key": "stripped", }, }, @@ -372,7 +373,7 @@ class PruneEventTestCase(unittest.TestCase): "type": "m.room.member", "content": { "membership": "join", - "join_authorised_via_users_server": "@user:domain", + EventContentFields.AUTHORISING_USER: "@user:domain", "other_key": "stripped", }, }, @@ -380,7 +381,7 @@ class PruneEventTestCase(unittest.TestCase): "type": "m.room.member", "content": { "membership": "join", - "join_authorised_via_users_server": "@user:domain", + EventContentFields.AUTHORISING_USER: "@user:domain", }, "signatures": {}, "unsigned": {}, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index e7a7d00883..cf407c51cf 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -16,6 +16,7 @@ import unittest from typing import Optional from synapse import event_auth +from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict @@ -353,7 +354,7 @@ class EventAuthTestCase(unittest.TestCase): authorised_join_event = _join_event( pleb, additional_content={ - "join_authorised_via_users_server": "@creator:example.com" + EventContentFields.AUTHORISING_USER: "@creator:example.com" }, ) event_auth.check_auth_rules_for_event( @@ -376,7 +377,7 @@ class EventAuthTestCase(unittest.TestCase): _join_event( pleb, additional_content={ - "join_authorised_via_users_server": "@inviter:foo.test" + EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), pl_auth_events, @@ -401,7 +402,7 @@ class EventAuthTestCase(unittest.TestCase): _join_event( pleb, additional_content={ - "join_authorised_via_users_server": "@other:example.com" + EventContentFields.AUTHORISING_USER: "@other:example.com" }, ), auth_events, @@ -417,7 +418,7 @@ class EventAuthTestCase(unittest.TestCase): "join", sender=creator, additional_content={ - "join_authorised_via_users_server": "@inviter:foo.test" + EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), auth_events, -- cgit 1.5.1 From 9e5a429c8b082d4cfbc0bd04c1ddde8822fd96b4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Sep 2021 14:06:02 -0400 Subject: Clean-up registration tests (#10945) Uses `override_config` and fixes test_auto_create_auto_join_where_no_consent to properly configure auto-join rooms. --- changelog.d/10945.misc | 1 + synapse/handlers/register.py | 4 +- tests/handlers/test_register.py | 89 ++++++++++++++++++++++++----------------- 3 files changed, 56 insertions(+), 38 deletions(-) create mode 100644 changelog.d/10945.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10945.misc b/changelog.d/10945.misc new file mode 100644 index 0000000000..7cf1f02ad6 --- /dev/null +++ b/changelog.d/10945.misc @@ -0,0 +1 @@ +Fix a broken test to ensure that consent configuration works during registration. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4a7ccb882e..cb4eb0720b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -340,6 +340,8 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() + # If the user does not need to consent at registration, auto-join any + # configured rooms. if not self.hs.config.consent.user_consent_at_registration: if not self.hs.config.auto_join_rooms_for_guests and make_guest: logger.info( @@ -387,7 +389,7 @@ class RegistrationHandler(BaseHandler): "preset": self.hs.config.registration.autocreate_auto_join_room_preset, } - # If the configuration providers a user ID to create rooms with, use + # If the configuration provides a user ID to create rooms with, use # that instead of the first user registered. requires_join = False if self.hs.config.registration.auto_join_user_id: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index bd05a2c2d1..db691c4c1c 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -16,7 +16,12 @@ from unittest.mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, ResourceLimitError, SynapseError +from synapse.api.errors import ( + CodeMessageException, + Codes, + ResourceLimitError, + SynapseError, +) from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, RoomID, UserID, create_requester @@ -120,14 +125,24 @@ class RegistrationTestCase(unittest.HomeserverTestCase): hs_config = self.default_config() # some of the tests rely on us having a user consent version - hs_config["user_consent"] = { - "version": "test_consent_version", - "template_dir": ".", - } + hs_config.setdefault("user_consent", {}).update( + { + "version": "test_consent_version", + "template_dir": ".", + } + ) hs_config["max_mau_value"] = 50 hs_config["limit_usage_by_mau"] = True - hs = self.setup_test_homeserver(config=hs_config) + # Don't attempt to reach out over federation. + self.mock_federation_client = Mock() + self.mock_federation_client.make_query.side_effect = CodeMessageException( + 500, "" + ) + + hs = self.setup_test_homeserver( + config=hs_config, federation_client=self.mock_federation_client + ) load_legacy_spam_checkers(hs) @@ -138,9 +153,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - self.mock_distributor = Mock() - self.mock_distributor.declare("registered_user") - self.mock_captcha_client = Mock() self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastore() self.lots_of_users = 100 @@ -174,21 +186,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) + @override_config({"limit_usage_by_mau": False}) def test_mau_limits_when_disabled(self): - self.hs.config.server.limit_usage_by_mau = False # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) + @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - self.hs.config.server.limit_usage_by_mau = True self.store.count_monthly_users = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) + @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - self.hs.config.server.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -205,8 +217,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) + @override_config({"limit_usage_by_mau": True}) def test_register_mau_blocked(self): - self.hs.config.server.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -221,10 +233,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart="local_part"), ResourceLimitError ) + @override_config( + {"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False} + ) def test_auto_join_rooms_for_guests(self): - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.hs.config.auto_join_rooms_for_guests = False user_id = self.get_success( self.handler.register_user(localpart="jeff", make_guest=True), ) @@ -243,34 +255,33 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) + @override_config({"auto_join_rooms": []}) def test_auto_create_auto_join_rooms_with_no_rooms(self): - self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:another"]}) def test_auto_create_auto_join_where_room_is_another_domain(self): - self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") user_id = self.get_success(self.handler.register_user(frank.localpart)) self.assertEqual(user_id, frank.to_string()) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config( + {"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False} + ) def test_auto_create_auto_join_where_auto_create_is_false(self): - self.hs.config.autocreate_auto_join_rooms = False - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -294,10 +305,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=make_awaitable(2)) self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) @@ -510,6 +519,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertEqual(rooms, set()) self.assertEqual(invited_rooms, []) + @override_config( + { + "user_consent": { + "block_events_error": "Error", + "require_at_registration": True, + }, + "form_secret": "53cr3t", + "public_baseurl": "http://test", + "auto_join_rooms": ["#room:test"], + }, + ) def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. @@ -521,25 +541,20 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # * The server is configured to auto-join to a room # (and autocreate if necessary) - event_creation_handler = self.hs.get_event_creation_handler() - # (Messing with the internals of event_creation_handler is fragile - # but can't see a better way to do this. One option could be to subclass - # the test with custom config.) - event_creation_handler._block_events_without_consent_error = "Error" - event_creation_handler._consent_uri_builder = Mock() - room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - # When:- - # * the user is registered and post consent actions are called + # * the user is registered user_id = self.get_success(self.handler.register_user(localpart="jeff")) - self.get_success(self.handler.post_consent_actions(user_id)) # Then:- # * Ensure that they have not been joined to the room rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + # The user provides consent; ensure they are now in the rooms. + self.get_success(self.handler.post_consent_actions(user_id)) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 1) + def test_register_support_user(self): user_id = self.get_success( self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) -- cgit 1.5.1 From a0f48ee89d88fd7b6da8023dbba607a69073152e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 4 Oct 2021 07:18:54 -0400 Subject: Use direct references for configuration variables (part 7). (#10959) --- changelog.d/10959.misc | 1 + synapse/handlers/auth.py | 2 +- synapse/handlers/identity.py | 13 ++++++++++--- synapse/handlers/profile.py | 4 ++-- synapse/handlers/register.py | 9 ++++++--- synapse/handlers/room_member.py | 2 +- synapse/handlers/ui_auth/checkers.py | 14 ++++++++------ synapse/rest/admin/users.py | 4 ++-- synapse/rest/client/account.py | 22 +++++++++++----------- synapse/rest/client/auth.py | 6 ++++-- synapse/rest/client/capabilities.py | 6 +++--- synapse/rest/client/login.py | 6 +++--- synapse/rest/client/register.py | 26 +++++++++++++------------- synapse/rest/well_known.py | 4 ++-- synapse/storage/databases/main/registration.py | 2 +- synapse/util/threepids.py | 4 ++-- tests/config/test_load.py | 6 +++--- tests/handlers/test_profile.py | 4 ++-- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/test_account.py | 4 ++-- tests/rest/client/test_identity.py | 2 +- tests/rest/client/test_register.py | 4 ++-- tests/unittest.py | 2 +- 23 files changed, 83 insertions(+), 68 deletions(-) create mode 100644 changelog.d/10959.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10959.misc b/changelog.d/10959.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10959.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a8c717efd5..2d0f3d566c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -198,7 +198,7 @@ class AuthHandler(BaseHandler): if inst.is_enabled(): self.checkers[inst.AUTH_TYPE] = inst # type: ignore - self.bcrypt_rounds = hs.config.bcrypt_rounds + self.bcrypt_rounds = hs.config.registration.bcrypt_rounds # we can't use hs.get_module_api() here, because to do so will create an # import loop. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index a0640fcac0..c881475c25 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -573,9 +573,15 @@ class IdentityHandler(BaseHandler): # Try to validate as email if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Remote emails will only be used if a valid identity server is provided. + assert ( + self.hs.config.registration.account_threepid_delegate_email is not None + ) + # Ask our delegated email identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details @@ -587,10 +593,11 @@ class IdentityHandler(BaseHandler): return validation_session # Try to validate as msisdn - if self.hs.config.account_threepid_delegate_msisdn: + if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) return validation_session diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 425c0d4973..2e19706c69 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -178,7 +178,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: + if not by_admin and not self.hs.config.registration.enable_set_displayname: profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( @@ -268,7 +268,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: + if not by_admin and not self.hs.config.registration.enable_set_avatar_url: profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cb4eb0720b..441af7a848 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,8 +116,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() - self.session_lifetime = hs.config.session_lifetime - self.access_token_lifetime = hs.config.access_token_lifetime + self.session_lifetime = hs.config.registration.session_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime init_counters_for_auth_provider("") @@ -343,7 +343,10 @@ class RegistrationHandler(BaseHandler): # If the user does not need to consent at registration, auto-join any # configured rooms. if not self.hs.config.consent.user_consent_at_registration: - if not self.hs.config.auto_join_rooms_for_guests and make_guest: + if ( + not self.hs.config.registration.auto_join_rooms_for_guests + and make_guest + ): logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", user_id, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 29b3e41cc9..c8fb24a20c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -89,7 +89,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid - self._enable_lookup = hs.config.enable_3pid_lookup + self._enable_lookup = hs.config.registration.enable_3pid_lookup self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8f5d465fa1..184730ebe8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker: # msisdns are currently always ThreepidBehaviour.REMOTE if medium == "msisdn": - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "Phone number verification is not enabled on this homeserver" ) threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) elif medium == "email": if ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE ): - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL @@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return bool(self.hs.config.account_threepid_delegate_msisdn) + return bool(self.hs.config.registration.account_threepid_delegate_msisdn) async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict) @@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration_requires_token) + self._enabled = bool(hs.config.registration.registration_requires_token) self.store = hs.get_datastore() def is_enabled(self) -> bool: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 46bfec4623..f20aa65301 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() - if not self.hs.config.registration_shared_secret: + if not self.hs.config.registration.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") body = parse_json_object_from_request(request) @@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet): got_mac = body["mac"] want_mac_builder = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), + key=self.hs.config.registration.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) want_mac_builder.update(nonce.encode("utf8")) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fff133ef10..6b272658fc 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - if not self.config.account_threepid_delegate_msisdn: + if not self.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "This homeserver is not validating phone numbers. Use an identity server " @@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.account_threepid_delegate_msisdn, + self.config.registration.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], body["token"], @@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 282861fae2..c9ad35a3ad 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -49,8 +49,10 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template - self.registration_token_template = hs.config.registration_token_template - self.success_template = hs.config.fallback_success_template + self.registration_token_template = ( + hs.config.registration.registration_token_template + ) + self.success_template = hs.config.registration.fallback_success_template async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index d6b6256413..2a3e24ae7e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3283_enabled: response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.enable_set_displayname + "enabled": self.config.registration.enable_set_displayname } response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.enable_set_avatar_url + "enabled": self.config.registration.enable_set_avatar_url } response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.enable_3pid_changes + "enabled": self.config.registration.enable_3pid_changes } return 200, response diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index fa5c173f4b..d49a647b03 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self.auth = hs.get_auth() @@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: + if hs.config.registration.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a6eb6f6410..bf3cb34146 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._registration_enabled = self.hs.config.registration.enable_registration + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet): async def _do_guest_registration( self, params: JsonDict, address: Optional[str] = None ) -> Tuple[int, JsonDict]: - if not self.hs.config.allow_guest_access: + if not self.hs.config.registration.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( make_guest=True, address=address @@ -849,13 +849,13 @@ def _calculate_registration_flows( """ # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid + require_email = "email" in config.registration.registrations_require_3pid + require_msisdn = "msisdn" in config.registration.registrations_require_3pid show_msisdn = True show_email = True - if config.disable_msisdn_registration: + if config.registration.disable_msisdn_registration: show_msisdn = False require_msisdn = False @@ -909,7 +909,7 @@ def _calculate_registration_flows( flow.insert(0, LoginType.RECAPTCHA) # Prepend registration token to all flows if we're requiring a token - if config.registration_requires_token: + if config.registration.registration_requires_token: for flow in flows: flow.insert(0, LoginType.REGISTRATION_TOKEN) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index c80a3a99aa..7ac01faab4 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -39,9 +39,9 @@ class WellKnownBuilder: result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} - if self._config.default_identity_server: + if self._config.registration.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server + "base_url": self._config.registration.default_identity_server } return result diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7279b0924e..de262fbf5a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): We do this by grandfathering in existing user threepids assuming that they used one of the server configured trusted identity servers. """ - id_servers = set(self.config.trusted_third_party_id_servers) + id_servers = set(self.config.registration.trusted_third_party_id_servers) def _bg_user_threepids_grandfather_txn(txn): sql = """ diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index baa9190a9a..389adf00f6 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: bool: whether the 3PID medium/address is allowed to be added to this HS """ - if hs.config.allowed_local_3pids: - for constraint in hs.config.allowed_local_3pids: + if hs.config.registration.allowed_local_3pids: + for constraint in hs.config.registration.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", address, diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ef6c2beec7..8e49ca26d9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -84,16 +84,16 @@ class ConfigLoadingTestCase(unittest.TestCase): ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( "", ["-c", self.file, "--enable-registration"] ) - self.assertTrue(config.enable_registration) + self.assertTrue(config.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 57cc3e2646..c153018fd8 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False + self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed self.get_success( @@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False + self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed self.get_success( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a285d5a7fe..6ed9e42173 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): If there is no shared secret, registration through this method will be prevented. """ - self.hs.config.registration_shared_secret = None + self.hs.config.registration.registration_shared_secret = None channel = self.make_request("POST", self.url, b"{}") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 2f44547bfb..89d85b0a17 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -664,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -734,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False # Add a threepid self.get_success( diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ca2e8ff8ef..becb4e8dcc 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index af135d57e1..66dcfc9f88 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True + self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False + self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/unittest.py b/tests/unittest.py index 0807467e39..1f803564f6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -560,7 +560,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user. """ - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" # Create the user channel = self.make_request("GET", "/_synapse/admin/v1/register") -- cgit 1.5.1 From f7b034a24bd5e64f05934453fe7b072894e124db Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 4 Oct 2021 12:45:51 +0100 Subject: Consistently exclude from user_directory (#10960) * Introduce `should_include_local_users_in_dir` We exclude three kinds of local users from the user_directory tables. At present we don't consistently exclude all three in the same places. This commit introduces a new function to gather those exclusion conditions together. Because we have to handle local and remote users in different ways, I've made that function only consider the case of remote users. It's the caller's responsibility to make the local versus remote distinction clear and correct. A test fixup is required. The test now hits a path which makes db queries against the users table. The expected rows were missing, because we were using a dummy user that hadn't actually been registered. We also add new test cases to covert the exclusion logic. ---- By my reading this makes these changes: * When an app service user registers or changes their profile, they will _not_ be added to the user directory. (Previously only support and deactivated users were excluded). This is consistent with the logic that rebuilds the user directory. See also [the discussion here](https://github.com/matrix-org/synapse/pull/10914#discussion_r716859548). * When rebuilding the directory, exclude support and disabled users from room sharing tables. Previously only appservice users were excluded. * Exclude all three categories of local users when rebuilding the directory. Previously `_populate_user_directory_process_users` didn't do any exclusion. Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/10960.bugfix | 1 + synapse/handlers/user_directory.py | 27 +-- synapse/storage/databases/main/user_directory.py | 46 ++++-- tests/handlers/test_user_directory.py | 200 +++++++++++++++++++++-- tests/rest/client/test_login.py | 17 +- tests/storage/test_user_directory.py | 146 ++++++++++++++++- tests/unittest.py | 29 ++++ 7 files changed, 409 insertions(+), 57 deletions(-) create mode 100644 changelog.d/10960.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/10960.bugfix b/changelog.d/10960.bugfix new file mode 100644 index 0000000000..b4f1c228ea --- /dev/null +++ b/changelog.d/10960.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where rebuilding the user directory wouldn't exclude support and disabled users. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index f4430ce3c9..18d8c8744e 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -132,12 +132,7 @@ class UserDirectoryHandler(StateDeltasHandler): # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - # Support users are for diagnostics and should not appear in the user directory. - is_support = await self.store.is_support_user(user_id) - # When change profile information of deactivated user it should not appear in the user directory. - is_deactivated = await self.store.get_user_deactivated_status(user_id) - - if not (is_support or is_deactivated): + if await self.store.should_include_local_user_in_dir(user_id): await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -229,8 +224,10 @@ class UserDirectoryHandler(StateDeltasHandler): else: logger.debug("Server is still in room: %r", room_id) - is_support = await self.store.is_support_user(state_key) - if not is_support: + include_in_dir = not self.is_mine_id( + state_key + ) or await self.store.should_include_local_user_in_dir(state_key) + if include_in_dir: if change is MatchChange.no_change: # Handle any profile changes await self._handle_profile_change( @@ -356,13 +353,7 @@ class UserDirectoryHandler(StateDeltasHandler): # First, if they're our user then we need to update for every user if self.is_mine_id(user_id): - - is_appservice = self.store.get_if_app_services_interested_in_user( - user_id - ) - - # We don't care about appservice users. - if not is_appservice: + if await self.store.should_include_local_user_in_dir(user_id): for other_user_id in other_users_in_room: if user_id == other_user_id: continue @@ -374,10 +365,10 @@ class UserDirectoryHandler(StateDeltasHandler): if user_id == other_user_id: continue - is_appservice = self.store.get_if_app_services_interested_in_user( + include_other_user = self.is_mine_id( other_user_id - ) - if self.is_mine_id(other_user_id) and not is_appservice: + ) and await self.store.should_include_local_user_in_dir(other_user_id) + if include_other_user: to_insert.add((other_user_id, user_id)) if to_insert: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index c26e3e066f..5f538947ec 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -40,12 +40,10 @@ from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) - TEMP_TABLE = "_temp_populate_user_directory" class UserDirectoryBackgroundUpdateStore(StateDeltasStore): - # How many records do we calculate before sending it to # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 @@ -235,6 +233,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) users_with_profile = await self.get_users_in_room_with_profiles(room_id) + # Throw away users excluded from the directory. + users_with_profile = { + user_id: profile + for user_id, profile in users_with_profile.items() + if not self.hs.is_mine_id(user_id) + or await self.should_include_local_user_in_dir(user_id) + } # Update each user in the user directory. for user_id, profile in users_with_profile.items(): @@ -246,9 +251,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): if is_public: for user_id in users_with_profile: - if self.get_if_app_services_interested_in_user(user_id): - continue - to_insert.add(user_id) if to_insert: @@ -256,12 +258,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): to_insert.clear() else: for user_id in users_with_profile: + # We want the set of pairs (L, M) where L and M are + # in `users_with_profile` and L is local. + # Do so by looking for the local user L first. if not self.hs.is_mine_id(user_id): continue - if self.get_if_app_services_interested_in_user(user_id): - continue - for other_user_id in users_with_profile: if user_id == other_user_id: continue @@ -349,10 +351,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) for user_id in users_to_work_on: - profile = await self.get_profileinfo(get_localpart_from_id(user_id)) - await self.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) + if await self.should_include_local_user_in_dir(user_id): + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + await self.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url + ) # We've finished processing a user. Delete it from the table. await self.db_pool.simple_delete_one( @@ -369,6 +372,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return len(users_to_work_on) + async def should_include_local_user_in_dir(self, user: str) -> bool: + """Certain classes of local user are omitted from the user directory. + Is this user one of them? + """ + # App service users aren't usually contactable, so exclude them. + if self.get_if_app_services_interested_in_user(user): + # TODO we might want to make this configurable for each app service + return False + + # Support users are for diagnostics and should not appear in the user directory. + if await self.is_support_user(user): + return False + + # Deactivated users aren't contactable, so should not appear in the user directory. + if await self.get_user_deactivated_status(user): + return False + return True + async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: """Check if the room is either world_readable or publically joinable""" @@ -537,7 +558,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): - # How many records do we calculate before sending it to # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 2988befb21..b3c3af113b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from unittest.mock import Mock, patch from urllib.parse import quote @@ -20,7 +21,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.rest.client import login, room, user_directory +from synapse.appservice import ApplicationService +from synapse.rest.client import login, register, room, user_directory from synapse.server import HomeServer from synapse.storage.roommember import ProfileInfo from synapse.types import create_requester @@ -28,6 +30,7 @@ from synapse.util import Clock from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables +from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -47,13 +50,29 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, synapse.rest.admin.register_servlets, + register.register_servlets, room.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) + + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() @@ -62,6 +81,137 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.event_creation_handler = self.hs.get_event_creation_handler() self.user_dir_helper = GetUserDirectoryTables(self.store) + def test_normal_user_pair(self) -> None: + """Sanity check that the room-sharing tables are updated correctly.""" + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + public = self.helper.create_room_as( + alice, + is_public=True, + extra_content={"visibility": "public"}, + tok=alice_token, + ) + private = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(private, alice, bob, tok=alice_token) + self.helper.join(public, bob, tok=bob_token) + self.helper.join(private, bob, tok=bob_token) + + # Alice also makes a second public room but no-one else joins + public2 = self.helper.create_room_as( + alice, + is_public=True, + extra_content={"visibility": "public"}, + tok=alice_token, + ) + + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + in_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + + self.assertEqual(users, {alice, bob}) + self.assertEqual( + set(in_public), {(alice, public), (bob, public), (alice, public2)} + ) + self.assertEqual( + self.user_dir_helper._compress_shared(in_private), + {(alice, bob, private), (bob, alice, private)}, + ) + + # The next three tests (test_population_excludes_*) all setup + # - A normal user included in the user dir + # - A public and private room created by that user + # - A user excluded from the room dir, belonging to both rooms + + # They match similar logic in storage/test_user_directory. But that tests + # rebuilding the directory; this tests updating it incrementally. + + def test_excludes_support_user(self) -> None: + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + public, private = self._create_rooms_and_inject_memberships( + alice, alice_token, support + ) + self._check_only_one_user_in_directory(alice, public) + + def test_excludes_deactivated_user(self) -> None: + admin = self.register_user("admin", "pass", admin=True) + admin_token = self.login(admin, "pass") + user = self.register_user("naughty", "pass") + + # Deactivate the user. + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{user}", + access_token=admin_token, + content={"deactivated": True}, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["deactivated"], True) + + # Join the deactivated user to rooms owned by the admin. + # Is this something that could actually happen outside of a test? + public, private = self._create_rooms_and_inject_memberships( + admin, admin_token, user + ) + self._check_only_one_user_in_directory(admin, public) + + def test_excludes_appservices_user(self) -> None: + # Register an AS user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + + # Join the AS user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, as_user + ) + self._check_only_one_user_in_directory(user, public) + + def _create_rooms_and_inject_memberships( + self, creator: str, token: str, joiner: str + ) -> Tuple[str, str]: + """Create a public and private room as a normal user. + Then get the `joiner` into those rooms. + """ + # TODO: Duplicates the same-named method in UserDirectoryInitialPopulationTest. + public_room = self.helper.create_room_as( + creator, + is_public=True, + # See https://github.com/matrix-org/synapse/issues/10951 + extra_content={"visibility": "public"}, + tok=token, + ) + private_room = self.helper.create_room_as(creator, is_public=False, tok=token) + + # HACK: get the user into these rooms + self.get_success(inject_member_event(self.hs, public_room, joiner, "join")) + self.get_success(inject_member_event(self.hs, private_room, joiner, "join")) + + return public_room, private_room + + def _check_only_one_user_in_directory(self, user: str, public: str) -> None: + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + in_private = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + + self.assertEqual(users, {user}) + self.assertEqual(set(in_public), {(user, public)}) + self.assertEqual(in_private, []) + def test_handle_local_profile_change_with_support_user(self) -> None: support_user_id = "@support:test" self.get_success( @@ -125,6 +275,26 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): profile = self.get_success(self.store.get_user_in_directory(r_user_id)) self.assertTrue(profile is None) + def test_handle_local_profile_change_with_appservice_user(self) -> None: + # create user + as_user_id = self.register_appservice_user( + "as_user_alice", self.appservice.token + ) + + # profile is not in directory + profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + self.assertTrue(profile is None) + + # update profile + profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3") + self.get_success( + self.handler.handle_local_profile_change(as_user_id, profile_info) + ) + + # profile is still not in directory + profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + self.assertTrue(profile is None) + def test_handle_user_deactivated_support_user(self) -> None: s_user_id = "@support:test" self.get_success( @@ -483,8 +653,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): class TestUserDirSearchDisabled(unittest.HomeserverTestCase): - user_id = "@test:test" - servlets = [ user_directory.register_servlets, room.register_servlets, @@ -504,16 +672,21 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): def test_disabling_room_list(self) -> None: self.config.userdirectory.user_directory_search_enabled = True - # First we create a room with another user so that user dir is non-empty - # for our user - self.helper.create_room_as(self.user_id) + # Create two users and put them in the same room. + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") u2 = self.register_user("user2", "pass") - room = self.helper.create_room_as(self.user_id) - self.helper.join(room, user=u2) + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) - # Assert user directory is not empty + # Each should see the other when searching the user directory. channel = self.make_request( - "POST", b"user_directory/search", b'{"search_term":"user2"}' + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + access_token=u1_token, ) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) @@ -521,7 +694,10 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.userdirectory.user_directory_search_enabled = False channel = self.make_request( - "POST", b"user_directory/search", b'{"search_term":"user2"}' + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + access_token=u1_token, ) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 7fd92c94e0..a63f04bd41 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1064,13 +1064,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def register_as_user(self, username): - self.make_request( - b"POST", - "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), - {"username": username}, - ) - def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() @@ -1107,7 +1100,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_user(self): """Test that an appservice user can use /login""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1121,7 +1114,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_user_bot(self): """Test that the appservice bot can use /login""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1135,7 +1128,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_wrong_user(self): """Test that non-as users cannot login with the as token""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1149,7 +1142,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_wrong_as(self): """Test that as users cannot login with wrong as token""" - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, @@ -1165,7 +1158,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): """Test that users must provide a token when using the appservice login method """ - self.register_as_user(AS_USER) + self.register_appservice_user(AS_USER, self.service.token) params = { "type": login.LoginRestServlet.APPSERVICE_TYPE, diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 74c8a8599e..6884ca9b7a 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, List, Set, Tuple +from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import UserTypes +from synapse.appservice import ApplicationService from synapse.rest import admin -from synapse.rest.client import login, room +from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage import DataStore from synapse.util import Clock +from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config ALICE = "@alice:a" @@ -64,6 +68,14 @@ class GetUserDirectoryTables: ["user_id", "other_user_id", "room_id"], ) + async def get_users_in_user_directory(self) -> Set[str]: + result = await self.store.db_pool.simple_select_list( + "user_directory", + None, + ["user_id"], + ) + return {row["user_id"] for row in result} + class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): """Ensure that rebuilding the directory writes the correct data to the DB. @@ -74,10 +86,28 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): servlets = [ login.register_servlets, - admin.register_servlets_for_client_rest_resource, + admin.register_servlets, room.register_servlets, + register.register_servlets, ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = super().make_homeserver(reactor, clock) + return hs + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.user_dir_helper = GetUserDirectoryTables(self.store) @@ -204,6 +234,118 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): {(u1, u3, private_room), (u3, u1, private_room)}, ) + # All three should have entries in the directory + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + self.assertEqual(users, {u1, u2, u3}) + + # The next three tests (test_population_excludes_*) all set up + # - A normal user included in the user dir + # - A public and private room created by that user + # - A user excluded from the room dir, belonging to both rooms + + # They match similar logic in handlers/test_user_directory.py But that tests + # updating the directory; this tests rebuilding it from scratch. + + def _create_rooms_and_inject_memberships( + self, creator: str, token: str, joiner: str + ) -> Tuple[str, str]: + """Create a public and private room as a normal user. + Then get the `joiner` into those rooms. + """ + public_room = self.helper.create_room_as( + creator, + is_public=True, + # See https://github.com/matrix-org/synapse/issues/10951 + extra_content={"visibility": "public"}, + tok=token, + ) + private_room = self.helper.create_room_as(creator, is_public=False, tok=token) + + # HACK: get the user into these rooms + self.get_success(inject_member_event(self.hs, public_room, joiner, "join")) + self.get_success(inject_member_event(self.hs, private_room, joiner, "join")) + + return public_room, private_room + + def _check_room_sharing_tables( + self, normal_user: str, public_room: str, private_room: str + ) -> None: + # After rebuilding the directory, we should only see the normal user. + users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) + self.assertEqual(users, {normal_user}) + in_public_rooms = self.get_success( + self.user_dir_helper.get_users_in_public_rooms() + ) + self.assertEqual(set(in_public_rooms), {(normal_user, public_room)}) + in_private_rooms = self.get_success( + self.user_dir_helper.get_users_who_share_private_rooms() + ) + self.assertEqual(in_private_rooms, []) + + def test_population_excludes_support_user(self) -> None: + # Create a normal and support user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Join the support user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, support + ) + + # Rebuild the directory. + self._purge_and_rebuild_user_dir() + + # Check the support user is not in the directory. + self._check_room_sharing_tables(user, public, private) + + def test_population_excludes_deactivated_user(self) -> None: + user = self.register_user("naughty", "pass") + admin = self.register_user("admin", "pass", admin=True) + admin_token = self.login(admin, "pass") + + # Deactivate the user. + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{user}", + access_token=admin_token, + content={"deactivated": True}, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["deactivated"], True) + + # Join the deactivated user to rooms owned by the admin. + # Is this something that could actually happen outside of a test? + public, private = self._create_rooms_and_inject_memberships( + admin, admin_token, user + ) + + # Rebuild the user dir. The deactivated user should be missing. + self._purge_and_rebuild_user_dir() + self._check_room_sharing_tables(admin, public, private) + + def test_population_excludes_appservice_user(self) -> None: + # Register an AS user. + user = self.register_user("user", "pass") + token = self.login(user, "pass") + as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + + # Join the AS user to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, as_user + ) + + # Rebuild the directory. + self._purge_and_rebuild_user_dir() + + # Check the AS user is not in the directory. + self._check_room_sharing_tables(user, public, private) + class UserDirectoryStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: diff --git a/tests/unittest.py b/tests/unittest.py index 1f803564f6..ae393ee53e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -596,6 +596,35 @@ class HomeserverTestCase(TestCase): user_id = channel.json_body["user_id"] return user_id + def register_appservice_user( + self, + username: str, + appservice_token: str, + ) -> str: + """Register an appservice user as an application service. + Requires the client-facing registration API be registered. + + Args: + username: the user to be registered by an application service. + Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart" + appservice_token: the acccess token for that application service. + + Raises: if the request to '/register' does not return 200 OK. + + Returns: the MXID of the new user. + """ + channel = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": username, + "type": "m.login.application_service", + }, + access_token=appservice_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body["user_id"] + def login( self, username, -- cgit 1.5.1 From eda8c88b84ee7506379a71ac2a7a88c08b759d43 Mon Sep 17 00:00:00 2001 From: Hillery Shay Date: Mon, 4 Oct 2021 08:34:42 -0700 Subject: Add functionality to remove deactivated users from the monthly_active_users table (#10947) * add test * add function to remove user from monthly active table in deactivate code * add function to remove user from monthly active table * add changelog entry * update changelog number * requested changes * update docstring on new function * fix lint error * Update synapse/storage/databases/main/monthly_active_users.py Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/10947.bugfix | 1 + synapse/handlers/deactivate_account.py | 4 +++ .../storage/databases/main/monthly_active_users.py | 24 ++++++++++++++ tests/test_mau.py | 37 ++++++++++++++++++++-- 4 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 changelog.d/10947.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/10947.bugfix b/changelog.d/10947.bugfix new file mode 100644 index 0000000000..40c70d3ece --- /dev/null +++ b/changelog.d/10947.bugfix @@ -0,0 +1 @@ +Fixes a long-standing bug wherin deactivated users still count towards the mau limit. \ No newline at end of file diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 9ae5b7750e..12bdca7445 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -133,6 +133,10 @@ class DeactivateAccountHandler(BaseHandler): # delete from user directory await self.user_directory_handler.handle_local_user_deactivated(user_id) + # If the user is present in the monthly active users table + # remove them + await self.store.remove_deactivated_user_from_mau_table(user_id) + # Mark the user as erased, if they asked for that if erase_data: user = UserID.from_string(user_id) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index a14ac03d4b..ec4d47a560 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -354,3 +354,27 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): await self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: await self.upsert_monthly_active_user(user_id) + + async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None: + """ + Removes a deactivated user from the monthly active user + table and resets affected caches. + + Args: + user_id(str): the user_id to remove + """ + + rows_deleted = await self.db_pool.simple_delete( + table="monthly_active_users", + keyvalues={"user_id": user_id}, + desc="simple_delete", + ) + + if rows_deleted != 0: + await self.invalidate_cache_and_stream( + "user_last_seen_monthly_active", (user_id,) + ) + await self.invalidate_cache_and_stream("get_monthly_active_count", ()) + await self.invalidate_cache_and_stream( + "get_monthly_active_count_by_service", () + ) diff --git a/tests/test_mau.py b/tests/test_mau.py index 80ab40e255..c683c8937e 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -13,11 +13,11 @@ # limitations under the License. """Tests REST events for /rooms paths.""" - +import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService -from synapse.rest.client import register, sync +from synapse.rest.client import login, profile, register, sync from tests import unittest from tests.unittest import override_config @@ -26,7 +26,13 @@ from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): - servlets = [register.register_servlets, sync.register_servlets] + servlets = [ + register.register_servlets, + sync.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + profile.register_servlets, + login.register_servlets, + ] def default_config(self): config = default_config("test") @@ -229,6 +235,31 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) + def test_deactivated_users_dont_count_towards_mau(self): + user1 = self.register_user("madonna", "password") + self.register_user("prince", "password2") + self.register_user("frodo", "onering", True) + + token1 = self.login("madonna", "password") + token2 = self.login("prince", "password2") + admin_token = self.login("frodo", "onering") + + self.do_sync_for_user(token1) + self.do_sync_for_user(token2) + + # Check that mau count is what we expect + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 2) + + # Deactivate user1 + url = "/_synapse/admin/v1/deactivate/%s" % user1 + channel = self.make_request("POST", url, access_token=admin_token) + self.assertIn("success", channel.json_body["id_server_unbind_result"]) + + # Check that deactivated user is no longer counted + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 1) + def create_user(self, localpart, token=None, appservice=False): request_data = { "username": localpart, -- cgit 1.5.1 From cb88ed912b3e984e0a409e4e5fd3c22817a4840d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 Oct 2021 12:50:07 +0100 Subject: `_check_event_auth`: move event validation earlier (#10988) There's little point in doing a fancy state reconciliation dance if the event itself is invalid. Likewise, there's no point checking it again in `_check_for_soft_fail`. --- changelog.d/10988.misc | 1 + synapse/handlers/federation_event.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10988.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10988.misc b/changelog.d/10988.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10988.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index e587b5b3b3..5938654338 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1250,9 +1250,18 @@ class FederationEventHandler: # This method should only be used for non-outliers assert not event.internal_metadata.outlier + # first of all, check that the event itself is valid. room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + try: + validate_event_for_room_version(room_version_obj, event) + except AuthError as e: + logger.warning("While validating received event %r: %s", event, e) + # TODO: use a different rejected reason here? + context.rejected = RejectedReason.AUTH_ERROR + return context + # calculate what the auth events *should* be, to use as a basis for auth. prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( @@ -1286,7 +1295,6 @@ class FederationEventHandler: auth_events_for_auth = calculated_auth_event_map try: - validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) @@ -1399,9 +1407,6 @@ class FederationEventHandler: } try: - # TODO: skip the call to validate_event_for_room_version? we should already - # have validated the event. - validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, current_auth_events) except AuthError as e: logger.warning( -- cgit 1.5.1 From d099535deb5be31891719c61c3757c5150829053 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 Oct 2021 12:50:38 +0100 Subject: `_update_auth_events_and_context_for_auth`: add some comments (#10987) Add some more comments about wtf is going on here. --- changelog.d/10987.misc | 1 + synapse/handlers/federation_event.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 changelog.d/10987.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10987.misc b/changelog.d/10987.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10987.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 5938654338..aa20d75550 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1476,6 +1476,11 @@ class FederationEventHandler: logger.debug("Events %s are in the store", have_events) missing_auth.difference_update(have_events) + # missing_auth is now the set of event_ids which: + # a. are listed in event.auth_events, *and* + # b. are *not* part of our calculated auth events based on room state, *and* + # c. are *not* yet in our database. + if missing_auth: # If we don't have all the auth events, we need to get them. logger.info("auth_events contains unknown events: %s", missing_auth) @@ -1497,10 +1502,31 @@ class FederationEventHandler: } ) + # auth_events now contains + # 1. our *calculated* auth events based on the room state, plus: + # 2. any events which: + # a. are listed in `event.auth_events`, *and* + # b. are not part of our calculated auth events, *and* + # c. were not in our database before the call to /event_auth + # d. have since been added to our database (most likely by /event_auth). + different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) + # different_auth is the set of events which *are* in `event.auth_events`, but + # which are *not* in `auth_events`. Comparing with (2.) above, this means + # exclusively the set of `event.auth_events` which we already had in our + # database before any call to /event_auth. + # + # I'm reasonably sure that the fact that events returned by /event_auth are + # blindly added to auth_events (and hence excluded from different_auth) is a bug + # - though it's a very long-standing one (see + # https://github.com/matrix-org/synapse/commit/78015948a7febb18e000651f72f8f58830a55b93#diff-0bc92da3d703202f5b9be2d3f845e375f5b1a6bc6ba61705a8af9be1121f5e42R786 + # from Jan 2015 which seems to add it, though it actually just moves it from + # elsewhere (before that, it gets lost in a mess of huge "various bug fixes" + # PRs). + if not different_auth: return context, auth_events -- cgit 1.5.1 From 787af4a1062fecc350fa14fe2abfc0e9d2f1555e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 Oct 2021 13:01:41 +0100 Subject: Host `cache_joined_hosts_for_event` to caller (#10986) `_check_event_auth` is only called in two places, and only one of those sets `send_on_behalf_of`. Warming the cache isn't really part of auth anyway, so moving it out makes a lot more sense. --- changelog.d/10986.misc | 1 + synapse/handlers/federation_event.py | 18 ++++++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10986.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10986.misc b/changelog.d/10986.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10986.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index aa20d75550..9269cb444d 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -356,6 +356,11 @@ class FederationEventHandler: ) # all looks good, we can persist the event. + + # First, precalculate the joined hosts so that the federation sender doesn't + # need to. + await self._event_creation_handler.cache_joined_hosts_for_event(event, context) + await self._run_push_actions_and_persist_event(event, context) return event, context @@ -1299,17 +1304,10 @@ class FederationEventHandler: except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR + return context - if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled, origin=origin) - await self._maybe_kick_guest_users(event) - - # If we are going to send this event over federation we precaclculate - # the joined hosts. - if event.internal_metadata.get_send_on_behalf_of(): - await self._event_creation_handler.cache_joined_hosts_for_event( - event, context - ) + await self._check_for_soft_fail(event, state, backfilled, origin=origin) + await self._maybe_kick_guest_users(event) return context -- cgit 1.5.1 From 392863fbf1ee31f8a1997446ab31919a7b6d9a14 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 5 Oct 2021 11:51:57 -0500 Subject: Fix logic flaw preventing tracking of MSC2716 events in existing room versions (#10962) We correctly allowed using the MSC2716 batch endpoint for the room creator in existing room versions but accidentally didn't track the events because of a logic flaw. This prevented you from connecting subsequent chunks together because it would throw the unknown batch ID error. We only want to process MSC2716 events when: - The room version supports MSC2716 - Any room where the homeserver has the `msc2716_enabled` experimental feature enabled and the event is from the room creator --- changelog.d/10962.bugfix | 1 + synapse/handlers/federation_event.py | 5 ++--- synapse/storage/databases/main/events.py | 10 ++++------ 3 files changed, 7 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10962.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/10962.bugfix b/changelog.d/10962.bugfix new file mode 100644 index 0000000000..9b0760d731 --- /dev/null +++ b/changelog.d/10962.bugfix @@ -0,0 +1 @@ +Fix [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint rejecting subsequent batches with unknown batch ID error in existing room versions from the room creator. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9269cb444d..243be46267 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1015,9 +1015,8 @@ class FederationEventHandler: room_version = await self._store.get_room_version(marker_event.room_id) create_event = await self._store.get_create_event_for_room(marker_event.room_id) room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) - if ( - not room_version.msc2716_historical - or not self._config.experimental.msc2716_enabled + if not room_version.msc2716_historical and ( + not self._config.experimental.msc2716_enabled or marker_event.sender != room_creator ): return diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index bc7d213fe2..19f55c19c5 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1763,9 +1763,8 @@ class PersistEventsStore: retcol="creator", allow_none=True, ) - if ( - not room_version.msc2716_historical - or not self.hs.config.experimental.msc2716_enabled + if not room_version.msc2716_historical and ( + not self.hs.config.experimental.msc2716_enabled or event.sender != room_creator ): return @@ -1825,9 +1824,8 @@ class PersistEventsStore: retcol="creator", allow_none=True, ) - if ( - not room_version.msc2716_historical - or not self.hs.config.experimental.msc2716_enabled + if not room_version.msc2716_historical and ( + not self.hs.config.experimental.msc2716_enabled or event.sender != room_creator ): return -- cgit 1.5.1 From 370bca32e60a854ab063f1abedb087dacae37e5a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 6 Oct 2021 13:56:45 +0100 Subject: Don't drop user dir deltas when server leaves room (#10982) Fix a long-standing bug where a batch of user directory changes would be silently dropped if the server left a room early in the batch. * Pull out `wait_for_background_update` in tests Co-authored-by: Patrick Cloke Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/10982.bugfix | 1 + synapse/handlers/user_directory.py | 2 +- tests/handlers/test_stats.py | 21 +++-------------- tests/handlers/test_user_directory.py | 39 +++++++++++++++++++++++++++++++ tests/storage/databases/main/test_room.py | 7 +----- tests/storage/test_cleanup_extrems.py | 7 +----- tests/storage/test_client_ips.py | 21 +++-------------- tests/storage/test_event_chain.py | 14 ++--------- tests/storage/test_roommember.py | 14 ++--------- tests/storage/test_user_directory.py | 7 +----- tests/unittest.py | 9 +++++++ 11 files changed, 63 insertions(+), 79 deletions(-) create mode 100644 changelog.d/10982.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/10982.bugfix b/changelog.d/10982.bugfix new file mode 100644 index 0000000000..5c9e15eeaa --- /dev/null +++ b/changelog.d/10982.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the remainder of a batch of user directory changes would be silently dropped if the server left a room early in the batch. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 18d8c8744e..97f60b5806 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -220,7 +220,7 @@ class UserDirectoryHandler(StateDeltasHandler): for user_id in user_ids: await self._handle_remove_user(room_id, user_id) - return + continue else: logger.debug("Server is still in room: %r", room_id) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 24b7ef6efc..56207f4db6 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -103,12 +103,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() def test_initial_room(self): """ @@ -140,12 +135,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() r = self.get_success(self.get_all_room_state()) @@ -568,12 +558,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b3c3af113b..03fd5a3e2c 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -363,6 +363,45 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) self.assertEqual(s["results"][0]["user_id"], user) + def test_process_join_after_server_leaves_room(self) -> None: + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + # Alice makes two rooms. Bob joins one of them. + room1 = self.helper.create_room_as(alice, tok=alice_token) + room2 = self.helper.create_room_as(alice, tok=alice_token) + print("room1=", room1) + print("room2=", room2) + self.helper.join(room1, bob, tok=bob_token) + + # The user sharing tables should have been updated. + public1 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertEqual(set(public1), {(alice, room1), (alice, room2), (bob, room1)}) + + # Alice leaves room1. The user sharing tables should be updated. + self.helper.leave(room1, alice, tok=alice_token) + public2 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertEqual(set(public2), {(alice, room2), (bob, room1)}) + + # Pause the processing of new events. + dir_handler = self.hs.get_user_directory_handler() + dir_handler.update_user_directory = False + + # Bob leaves one room and joins the other. + self.helper.leave(room1, bob, tok=bob_token) + self.helper.join(room2, bob, tok=bob_token) + + # Process the leave and join in one go. + dir_handler.update_user_directory = True + dir_handler.notify_new_event() + self.wait_for_background_updates() + + # The user sharing tables should have been updated. + public3 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertEqual(set(public3), {(alice, room2), (bob, room2)}) + def test_private_room(self) -> None: """ A user can be searched for only by people that are either in a public diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index ffee707153..7496974da3 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -79,12 +79,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Make sure the background update filled in the room creator room_creator_after = self.get_success( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 7cc5e621ba..a59c28f896 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -66,12 +66,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 3cc8038f1e..dada4f98c9 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -242,12 +242,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() user_id = "@user:id" device_id = "MY_DEVICE" @@ -311,12 +306,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # We should now get the correct result again result = self.get_success( @@ -337,12 +327,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() user_id = "@user:id" device_id = "MY_DEVICE" diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 93136f0717..b31c5eb5ec 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -578,12 +578,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Test that the `has_auth_chain_index` has been set self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) @@ -619,12 +614,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Test that the `has_auth_chain_index` has been set self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index c72dc40510..2873e22ccf 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -169,12 +169,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -197,9 +192,4 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index fddfb8db28..9f483ad681 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -212,12 +212,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): ) ) - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) + self.wait_for_background_updates() def test_initial(self) -> None: """ diff --git a/tests/unittest.py b/tests/unittest.py index ae393ee53e..81c1a9e9d2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -317,6 +317,15 @@ class HomeserverTestCase(TestCase): self.reactor.advance(0.01) time.sleep(0.01) + def wait_for_background_updates(self) -> None: + """Block until all background database updates have completed.""" + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + def make_homeserver(self, reactor, clock): """ Make and return a homeserver. -- cgit 1.5.1 From 829f2a82b042d944fef3df55faec924502cdf20d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 6 Oct 2021 16:32:16 +0200 Subject: Add a spamchecker callback to allow or deny room joins (#10910) Co-authored-by: Erik Johnston --- changelog.d/10910.feature | 1 + docs/modules/spam_checker_callbacks.md | 15 +++++ synapse/events/spamcheck.py | 24 ++++++++ synapse/handlers/room.py | 2 + synapse/handlers/room_member.py | 31 ++++++++++ tests/rest/client/test_rooms.py | 101 +++++++++++++++++++++++++++++++++ 6 files changed, 174 insertions(+) create mode 100644 changelog.d/10910.feature (limited to 'synapse/handlers') diff --git a/changelog.d/10910.feature b/changelog.d/10910.feature new file mode 100644 index 0000000000..aee139f8b6 --- /dev/null +++ b/changelog.d/10910.feature @@ -0,0 +1 @@ +Add a spam checker callback to allow or deny room joins. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 7920ac5f8f..92376df993 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -19,6 +19,21 @@ either a `bool` to indicate whether the event must be rejected because of spam, to indicate the event must be rejected because of spam and to give a rejection reason to forward to clients. +### `user_may_join_room` + +```python +async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool +``` + +Called when a user is trying to join a room. The module must return a `bool` to indicate +whether the user can join the room. The user is represented by their Matrix user ID (e.g. +`@alice:example.com`) and the room is represented by its Matrix ID (e.g. +`!room:example.com`). The module is also given a boolean to indicate whether the user +currently has a pending invite in the room. + +This callback isn't called if the join is performed by a server administrator, or in the +context of a room creation. + ### `user_may_invite` ```python diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index c389f70b8d..ec8863e397 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -44,6 +44,7 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ["synapse.events.EventBase"], Awaitable[Union[bool, str]], ] +USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[ @@ -165,6 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): class SpamChecker: def __init__(self): self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] + self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] self._user_may_create_room_with_invites_callbacks: List[ @@ -187,6 +189,7 @@ class SpamChecker: def register_callbacks( self, check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, + user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, user_may_create_room_with_invites: Optional[ @@ -206,6 +209,9 @@ class SpamChecker: if check_event_for_spam is not None: self._check_event_for_spam_callbacks.append(check_event_for_spam) + if user_may_join_room is not None: + self._user_may_join_room_callbacks.append(user_may_join_room) + if user_may_invite is not None: self._user_may_invite_callbacks.append(user_may_invite) @@ -259,6 +265,24 @@ class SpamChecker: return False + async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool): + """Checks if a given users is allowed to join a room. + Not called when a user creates a room. + + Args: + userid: The ID of the user wanting to join the room + room_id: The ID of the room the user wants to join + is_invited: Whether the user is invited into the room + + Returns: + bool: Whether the user may join the room + """ + for callback in self._user_may_join_room_callbacks: + if await callback(user_id, room_id, is_invited) is False: + return False + + return True + async def user_may_invite( self, inviter_userid: str, invitee_userid: str, room_id: str ) -> bool: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 873e08258e..d40dbd761d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -860,6 +860,7 @@ class RoomCreationHandler(BaseHandler): "invite", ratelimit=False, content=content, + new_room=True, ) for invite_3pid in invite_3pid_list: @@ -962,6 +963,7 @@ class RoomCreationHandler(BaseHandler): "join", ratelimit=ratelimit, content=creator_join_profile, + new_room=True, ) # We treat the power levels override specially as this needs to be one diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c8fb24a20c..0b79dbcf8d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -434,6 +434,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, outlier: bool = False, prev_event_ids: Optional[List[str]] = None, @@ -451,6 +452,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Information from a 3PID invite. ratelimit: Whether to rate limit the request. content: The content of the created event. + new_room: Whether the membership update is happening in the context of a room + creation. require_consent: Whether consent is required. outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as @@ -485,6 +488,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed=third_party_signed, ratelimit=ratelimit, content=content, + new_room=new_room, require_consent=require_consent, outlier=outlier, prev_event_ids=prev_event_ids, @@ -504,6 +508,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, outlier: bool = False, prev_event_ids: Optional[List[str]] = None, @@ -523,6 +528,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: ratelimit: content: + new_room: Whether the membership update is happening in the context of a room + creation. require_consent: outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as @@ -726,6 +733,30 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") + # Figure out whether the user is a server admin to determine whether they + # should be able to bypass the spam checker. + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to join rooms + bypass_spam_checker = True + + else: + bypass_spam_checker = await self.auth.is_server_admin(requester.user) + + inviter = await self._get_inviter(target.to_string(), room_id) + if ( + not bypass_spam_checker + # We assume that if the spam checker allowed the user to create + # a room then they're allowed to join it. + and not new_room + and not await self.spam_checker.user_may_join_room( + target.to_string(), room_id, is_invited=inviter is not None + ) + ): + raise SynapseError(403, "Not allowed to join this room") + # Check if a remote join should be performed. remote_join, remote_room_hosts = await self._should_perform_remote_join( target.to_string(), room_id, remote_room_hosts, content, is_host_in_room diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 30bdaa9c27..a41ec6a98f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase): # Check that do_3pid_invite wasn't called this time. self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) + def test_spam_checker_may_join_room(self): + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + """ + + async def user_may_join_room( + mxid: str, + room_id: str, + is_invite: bool, + ) -> bool: + return False + + join_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEquals(channel.code, 200, channel.json_body) + + self.assertEquals(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase): self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) +class RoomJoinTestCase(RoomBase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user1 = self.register_user("thomas", "hackme") + self.tok1 = self.login("thomas", "hackme") + + self.user2 = self.register_user("teresa", "hackme") + self.tok2 = self.login("teresa", "hackme") + + self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + + def test_spam_checker_may_join_room(self): + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value = True + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> bool: + return return_value + + callback_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + return_value = False + self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" -- cgit 1.5.1 From f4b1a9a527273ef71b2f7d970642b7af45462e0f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 6 Oct 2021 10:47:41 -0400 Subject: Require direct references to configuration variables. (#10985) This removes the magic allowing accessing configurable variables directly from the config object. It is now required that a specific configuration class is used (e.g. `config.foo` must be replaced with `config.server.foo`). --- changelog.d/10985.misc | 1 + scripts/synapse_port_db | 4 +- scripts/update_synapse_database | 2 +- synapse/app/_base.py | 2 +- synapse/app/admin_cmd.py | 4 +- synapse/app/homeserver.py | 2 +- synapse/config/_base.py | 64 ++++---------------------- synapse/config/account_validity.py | 2 +- synapse/config/cas.py | 2 +- synapse/config/emailconfig.py | 9 ++-- synapse/config/key.py | 6 ++- synapse/config/oidc.py | 2 +- synapse/config/registration.py | 7 ++- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/server_notices.py | 4 +- synapse/config/sso.py | 6 ++- synapse/handlers/account_validity.py | 8 +--- synapse/handlers/room_member.py | 7 ++- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/handler.py | 7 ++- synapse/rest/client/auth.py | 2 +- synapse/rest/client/push_rule.py | 4 +- synapse/storage/databases/main/push_rule.py | 4 +- synapse/storage/databases/main/registration.py | 4 +- tests/config/test_base.py | 21 +++++---- tests/config/test_cache.py | 50 ++++++++------------ tests/config/test_load.py | 12 +++-- tests/config/test_tls.py | 38 +++++++-------- tests/storage/test_appservice.py | 2 +- tests/storage/test_txn_limit.py | 2 +- 31 files changed, 124 insertions(+), 160 deletions(-) create mode 100644 changelog.d/10985.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10985.misc b/changelog.d/10985.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10985.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index fa6ac6d93a..a947d9e49e 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -215,7 +215,7 @@ class MockHomeserver: def __init__(self, config): self.clock = Clock(reactor) self.config = config - self.hostname = config.server_name + self.hostname = config.server.server_name self.version_string = "Synapse/" + get_version_string(synapse) def get_clock(self): @@ -583,7 +583,7 @@ class Porter(object): return self.postgres_store = self.build_db_store( - self.hs_config.get_single_database() + self.hs_config.database.get_single_database() ) await self.run_background_updates_on_postgres() diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 26b29b0b45..6c088bad93 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -36,7 +36,7 @@ class MockHomeserver(HomeServer): def __init__(self, config, **kwargs): super(MockHomeserver, self).__init__( - config.server_name, reactor=reactor, config=config, **kwargs + config.server.server_name, reactor=reactor, config=config, **kwargs ) self.version_string = "Synapse/" + get_version_string(synapse) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 749bc1deb9..4a204a5823 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -301,7 +301,7 @@ def refresh_certificate(hs): if not hs.config.server.has_tls_listener(): return - hs.config.read_certificate_from_disk() + hs.config.tls.read_certificate_from_disk() hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config) if hs._listening_services: diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 556bcc124e..13d20af457 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -197,9 +197,9 @@ def start(config_options): # Explicitly disable background processes config.server.update_user_directory = False config.worker.run_background_tasks = False - config.start_pushers = False + config.worker.start_pushers = False config.pusher_shard_config.instances = [] - config.send_federation = False + config.worker.send_federation = False config.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b2d4bbf83..422f03cc04 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["media", "federation", "client"]: - if self.config.media.enable_media_repo: + if self.config.server.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 26152b0924..7c4428a138 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -118,21 +118,6 @@ class Config: "synapse", "res/templates" ) - def __getattr__(self, item: str) -> Any: - """ - Try and fetch a configuration option that does not exist on this class. - - This is so that existing configs that rely on `self.value`, where value - is actually from a different config section, continue to work. - """ - if item in ["generate_config_section", "read_config"]: - raise AttributeError(item) - - if self.root is None: - raise AttributeError(item) - else: - return self.root._get_unclassed_config(self.section, item) - @staticmethod def parse_size(value): if isinstance(value, int): @@ -289,7 +274,9 @@ class Config: env.filters.update( { "format_ts": _format_ts_filter, - "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + "mxc_to_http": _create_mxc_to_http_filter( + self.root.server.public_baseurl + ), } ) @@ -311,8 +298,6 @@ class RootConfig: config_classes = [] def __init__(self): - self._configs = OrderedDict() - for config_class in self.config_classes: if config_class.section is None: raise ValueError("%r requires a section name" % (config_class,)) @@ -321,42 +306,7 @@ class RootConfig: conf = config_class(self) except Exception as e: raise Exception("Failed making %s: %r" % (config_class.section, e)) - self._configs[config_class.section] = conf - - def __getattr__(self, item: str) -> Any: - """ - Redirect lookups on this object either to config objects, or values on - config objects, so that `config.tls.blah` works, as well as legacy uses - of things like `config.server.server_name`. It will first look up the config - section name, and then values on those config classes. - """ - if item in self._configs.keys(): - return self._configs[item] - - return self._get_unclassed_config(None, item) - - def _get_unclassed_config(self, asking_section: Optional[str], item: str): - """ - Fetch a config value from one of the instantiated config classes that - has not been fetched directly. - - Args: - asking_section: If this check is coming from a Config child, which - one? This section will not be asked if it has the value. - item: The configuration value key. - - Raises: - AttributeError if no config classes have the config key. The body - will contain what sections were checked. - """ - for key, val in self._configs.items(): - if key == asking_section: - continue - - if item in dir(val): - return getattr(val, item) - - raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),)) + setattr(self, config_class.section, conf) def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: """ @@ -373,9 +323,11 @@ class RootConfig: """ res = OrderedDict() - for name, config in self._configs.items(): + for config_class in self.config_classes: + config = getattr(self, config_class.section) + if hasattr(config, func_name): - res[name] = getattr(config, func_name)(*args, **kwargs) + res[config_class.section] = getattr(config, func_name)(*args, **kwargs) return res diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index ffaffc4931..b56c2a24df 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -76,7 +76,7 @@ class AccountValidityConfig(Config): ) if self.account_validity_renew_by_email_enabled: - if not self.public_baseurl: + if not self.root.server.public_baseurl: raise ConfigError("Can't send renewal emails without 'public_baseurl'") # Load account validity templates. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 901f4123e1..9b58ecf3d8 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -37,7 +37,7 @@ class CasConfig(Config): # The public baseurl is required because it is used by the redirect # template. - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if not public_baseurl: raise ConfigError("cas_config requires a public_baseurl to be set") diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 936abe6178..8ff59aa2f8 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -19,7 +19,6 @@ import email.utils import logging import os from enum import Enum -from typing import Optional import attr @@ -135,7 +134,7 @@ class EmailConfig(Config): # msisdn is currently always remote while Synapse does not support any method of # sending SMS messages ThreepidBehaviour.REMOTE - if self.account_threepid_delegate_email + if self.root.registration.account_threepid_delegate_email else ThreepidBehaviour.LOCAL ) # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would @@ -144,7 +143,7 @@ class EmailConfig(Config): # identity server in the process. self.using_identity_server_from_trusted_list = False if ( - not self.account_threepid_delegate_email + not self.root.registration.account_threepid_delegate_email and config.get("trust_identity_server_for_password_resets", False) is True ): # Use the first entry in self.trusted_third_party_id_servers instead @@ -156,7 +155,7 @@ class EmailConfig(Config): # trusted_third_party_id_servers does not contain a scheme whereas # account_threepid_delegate_email is expected to. Presume https - self.account_threepid_delegate_email: Optional[str] = ( + self.root.registration.account_threepid_delegate_email = ( "https://" + first_trusted_identity_server ) self.using_identity_server_from_trusted_list = True @@ -335,7 +334,7 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if self.account_validity_renew_by_email_enabled: + if self.root.account_validity.account_validity_renew_by_email_enabled: expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) diff --git a/synapse/config/key.py b/synapse/config/key.py index 94a9063043..015dbb8a67 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -145,11 +145,13 @@ class KeyConfig(Config): # list of TrustedKeyServer objects self.key_servers = list( - _parse_key_servers(key_servers, self.federation_verify_certificates) + _parse_key_servers( + key_servers, self.root.tls.federation_verify_certificates + ) ) self.macaroon_secret_key = config.get( - "macaroon_secret_key", self.registration_shared_secret + "macaroon_secret_key", self.root.registration.registration_shared_secret ) if not self.macaroon_secret_key: diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 7e67fbada1..10f5796330 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -58,7 +58,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7cffdacfa5..a3d2a38c4c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -45,7 +45,10 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: + if ( + self.account_threepid_delegate_msisdn + and not self.root.server.public_baseurl + ): raise ConfigError( "The configuration option `public_baseurl` is required if " "`account_threepid_delegate.msisdn` is set, such that " @@ -85,7 +88,7 @@ class RegistrationConfig(Config): if mxid_localpart: # Convert the localpart to a full mxid. self.auto_join_user_id = UserID( - mxid_localpart, self.server_name + mxid_localpart, self.root.server.server_name ).to_string() if self.autocreate_auto_join_rooms: diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 7481f3bf5f..69906a98d4 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -94,7 +94,7 @@ class ContentRepositoryConfig(Config): # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( - self.enable_media_repo is False + self.root.server.enable_media_repo is False and config.get("worker_app") != "synapse.app.media_repository" ): self.can_load_media_repo = False diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 05e983625d..9c51b6a25a 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -199,7 +199,7 @@ class SAML2Config(Config): """ import saml2 - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py index 48bf3241b6..bde4e879d9 100644 --- a/synapse/config/server_notices.py +++ b/synapse/config/server_notices.py @@ -73,7 +73,9 @@ class ServerNoticesConfig(Config): return mxid_localpart = c["system_mxid_localpart"] - self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string() + self.server_notices_mxid = UserID( + mxid_localpart, self.root.server.server_name + ).to_string() self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None) self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None) # todo: i18n diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 524a7ff3aa..11a9b76aa0 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -103,8 +103,10 @@ class SSOConfig(Config): # the client's. # public_baseurl is an optional setting, so we only add the fallback's URL to the # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + if self.root.server.public_baseurl: + login_fallback_url = ( + self.root.server.public_baseurl + "_matrix/static/client/login" + ) self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 5a5f124ddf..87e415df75 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -67,12 +67,8 @@ class AccountValidityHandler: and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = ( - hs.config.account_validity.account_validity_template_html - ) - self._template_text = ( - hs.config.account_validity.account_validity_template_text - ) + self._template_html = hs.config.email.account_validity_template_html + self._template_text = hs.config.email.account_validity_template_text self._renew_email_subject = ( hs.config.account_validity.account_validity_renew_email_subject ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0b79dbcf8d..c05461bf2a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1499,8 +1499,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - check_complexity = self.hs.config.limit_remote_rooms.enabled - if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = self.hs.config.server.limit_remote_rooms.enabled + if ( + check_complexity + and self.hs.config.server.limit_remote_rooms.admins_can_join + ): check_complexity = not await self.auth.is_server_admin(user) if check_complexity: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 37769ace48..961c17762e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -117,7 +117,7 @@ class ReplicationDataHandler: self._instance_name = hs.get_instance_name() self._typing_handler = hs.get_typing_handler() - self._notify_pushers = hs.config.start_pushers + self._notify_pushers = hs.config.worker.start_pushers self._pusher_pool = hs.get_pusherpool() self._presence_handler = hs.get_presence_handler() diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d64d1dbacd..6aa9318027 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -171,7 +171,10 @@ class ReplicationCommandHandler: if hs.config.worker.worker_app is not None: continue - if stream.NAME == FederationStream.NAME and hs.config.send_federation: + if ( + stream.NAME == FederationStream.NAME + and hs.config.worker.send_federation + ): # We only support federation stream if federation sending # has been disabled on the master. continue @@ -225,7 +228,7 @@ class ReplicationCommandHandler: self._is_master = hs.config.worker.worker_app is None self._federation_sender = None - if self._is_master and not hs.config.send_federation: + if self._is_master and not hs.config.worker.send_federation: self._federation_sender = hs.get_federation_sender() self._server_notices_sender = None diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index c9ad35a3ad..9c15a04338 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -48,7 +48,7 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template - self.terms_template = hs.config.terms_template + self.terms_template = hs.config.consent.terms_template self.registration_token_template = ( hs.config.registration.registration_token_template ) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index ecebc46e8d..6f796d5e50 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -61,7 +61,9 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a7fb8cd848..b81e33964a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -101,7 +101,9 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) @abc.abstractmethod def get_max_push_rules_stream_id(self): diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de262fbf5a..7de4ad7f9b 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1778,7 +1778,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + self._ignore_unknown_session_error = ( + hs.config.server.request_token_inhibit_3pid_errors + ) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") diff --git a/tests/config/test_base.py b/tests/config/test_base.py index baa5313fb3..6a52f862f4 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -14,23 +14,28 @@ import os.path import tempfile +from unittest.mock import Mock from synapse.config import ConfigError +from synapse.config._base import Config from synapse.util.stringutils import random_string from tests import unittest -class BaseConfigTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): - self.hs = hs +class BaseConfigTestCase(unittest.TestCase): + def setUp(self): + # The root object needs a server property with a public_baseurl. + root = Mock() + root.server.public_baseurl = "http://test" + self.config = Config(root) def test_loading_missing_templates(self): # Use a temporary directory that exists on the system, but that isn't likely to # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] + template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], (tmp_dir,)) + self.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [template_filename], (td.name for td in tempdirs), ) @@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [other_template_name], (td.name for td in tempdirs), ) @@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): - self.hs.config.read_templates( + self.config.read_templates( ["some_filename.html"], ("a_nonexistent_directory",) ) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 857d9cd096..f518abdb7a 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -12,39 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.config._base import Config, RootConfig from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase -class FakeServer(Config): - section = "server" - - -class TestConfig(RootConfig): - config_classes = [FakeServer, CacheConfig] - - class CacheConfigTests(TestCase): def setUp(self): # Reset caches before each test - TestConfig().caches.reset() + self.config = CacheConfig() + + def tearDown(self): + self.config.reset() def test_individual_caches_from_environ(self): """ Individual cache factors will be loaded from the environment. """ config = {} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_NOT_CACHE": "BLAH", } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) def test_config_overrides_environ(self): """ @@ -52,15 +45,14 @@ class CacheConfigTests(TestCase): over those in the config. """ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_CACHE_FACTOR_FOO": 1, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( - dict(t.caches.cache_factors), + dict(self.config.cache_factors), {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) @@ -76,8 +68,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"per_cache_factors": {"foo": 3}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config) self.assertEqual(cache.max_size, 300) @@ -88,8 +79,7 @@ class CacheConfigTests(TestCase): there is one. """ config = {"caches": {"per_cache_factors": {"foo": 2}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -106,8 +96,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"global_factor": 4}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(cache.max_size, 400) @@ -118,8 +107,7 @@ class CacheConfigTests(TestCase): is no per-cache factor. """ config = {"caches": {"global_factor": 1.5}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -133,12 +121,11 @@ class CacheConfigTests(TestCase): "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} } } - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -158,11 +145,10 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"event_cache_size": "10k"}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache( - max_size=t.caches.event_cache_size, + max_size=self.config.event_cache_size, apply_cache_factor_from_config=False, ) add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 8e49ca26d9..59635de205 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -49,7 +49,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -60,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase): config1 = HomeServerConfig.load_config("", ["-c", self.file]) config2 = HomeServerConfig.load_config("", ["-c", self.file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) - self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key) + self.assertEqual( + config1.key.macaroon_secret_key, config2.key.macaroon_secret_key + ) + self.assertEqual( + config1.key.macaroon_secret_key, config3.key.macaroon_secret_key + ) def test_disable_registration(self): self.generate_config() diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index b6bc1876b5..9ba5781573 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -42,9 +42,9 @@ class TLSConfigTests(TestCase): """ config = {} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") def test_tls_client_minimum_set(self): """ @@ -52,29 +52,29 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": 1.1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1") config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") # Also test a string version config = {"federation_client_minimum_tls_version": "1"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": "1.2"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") def test_tls_client_minimum_1_point_3_missing(self): """ @@ -91,7 +91,7 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() with self.assertRaises(ConfigError) as e: - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( e.exception.args[0], ( @@ -112,8 +112,8 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.3") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") def test_tls_client_minimum_set_passed_through_1_2(self): """ @@ -121,7 +121,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -137,7 +137,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -159,7 +159,7 @@ class TLSConfigTests(TestCase): } t = TestConfig() e = self.assertRaises( - ConfigError, t.read_config, config, config_dir_path="", data_dir_path="" + ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path="" ) self.assertIn("IDNA domain names", str(e)) @@ -174,7 +174,7 @@ class TLSConfigTests(TestCase): ] } t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cf9748f218..f26d5acf9c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -126,7 +126,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.db_pool = database._db_pool self.engine = database.engine - db_config = hs.config.get_single_database() + db_config = hs.config.database.get_single_database() self.store = TestTransactionStore( database, make_conn(db_config, self.engine, "test"), hs ) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 6ff3ebb137..ace82cbf42 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -22,7 +22,7 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): return self.setup_test_homeserver(db_txn_limit=1000) def test_config(self): - db_config = self.hs.config.get_single_database() + db_config = self.hs.config.database.get_single_database() self.assertEqual(db_config.config["txn_limit"], 1000) def test_select(self): -- cgit 1.5.1 From 4e5162106436f3fddd12561d316d19fd23148800 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 6 Oct 2021 17:18:13 +0200 Subject: Add a spamchecker method to allow or deny 3pid invites (#10894) This is in the context of creating new module callbacks that modules in https://github.com/matrix-org/synapse-dinsic can use, in an effort to reconcile the spam checker API in synapse-dinsic with the one in mainline. Note that a module callback already exists for 3pid invites (https://matrix-org.github.io/synapse/develop/modules/third_party_rules_callbacks.html#check_threepid_can_be_invited) but it doesn't check whether the sender of the invite is allowed to send it. --- changelog.d/10894.feature | 1 + docs/modules/spam_checker_callbacks.md | 35 +++++++++++++++++ synapse/events/spamcheck.py | 35 +++++++++++++++++ synapse/handlers/room_member.py | 12 ++++++ tests/rest/client/test_rooms.py | 70 ++++++++++++++++++++++++++++++++++ 5 files changed, 153 insertions(+) create mode 100644 changelog.d/10894.feature (limited to 'synapse/handlers') diff --git a/changelog.d/10894.feature b/changelog.d/10894.feature new file mode 100644 index 0000000000..a4f968bed1 --- /dev/null +++ b/changelog.d/10894.feature @@ -0,0 +1 @@ +Add a `user_may_send_3pid_invite` spam checker callback for modules to allow or deny 3PID invites. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 92376df993..787e99074a 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -44,6 +44,41 @@ Called when processing an invitation. The module must return a `bool` indicating the inviter can invite the invitee to the given room. Both inviter and invitee are represented by their Matrix user ID (e.g. `@alice:example.com`). +### `user_may_send_3pid_invite` + +```python +async def user_may_send_3pid_invite( + inviter: str, + medium: str, + address: str, + room_id: str, +) -> bool +``` + +Called when processing an invitation using a third-party identifier (also called a 3PID, +e.g. an email address or a phone number). The module must return a `bool` indicating +whether the inviter can invite the invitee to the given room. + +The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the +invitee is represented by its medium (e.g. "email") and its address +(e.g. `alice@example.com`). See [the Matrix specification](https://matrix.org/docs/spec/appendices#pid-types) +for more information regarding third-party identifiers. + +For example, a call to this callback to send an invitation to the email address +`alice@example.com` would look like this: + +```python +await user_may_send_3pid_invite( + "@bob:example.com", # The inviter's user ID + "email", # The medium of the 3PID to invite + "alice@example.com", # The address of the 3PID to invite + "!some_room:example.com", # The ID of the room to send the invite into +) +``` + +**Note**: If the third-party identifier is already associated with a matrix user ID, +[`user_may_invite`](#user_may_invite) will be used instead. + ### `user_may_create_room` ```python diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index ec8863e397..ae4c8ab257 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -46,6 +46,7 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ] USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]] +USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[ [str, List[str], List[Dict[str, str]]], Awaitable[bool] @@ -168,6 +169,9 @@ class SpamChecker: self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] + self._user_may_send_3pid_invite_callbacks: List[ + USER_MAY_SEND_3PID_INVITE_CALLBACK + ] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] self._user_may_create_room_with_invites_callbacks: List[ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK @@ -191,6 +195,7 @@ class SpamChecker: check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, + user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, user_may_create_room_with_invites: Optional[ USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK @@ -215,6 +220,11 @@ class SpamChecker: if user_may_invite is not None: self._user_may_invite_callbacks.append(user_may_invite) + if user_may_send_3pid_invite is not None: + self._user_may_send_3pid_invite_callbacks.append( + user_may_send_3pid_invite, + ) + if user_may_create_room is not None: self._user_may_create_room_callbacks.append(user_may_create_room) @@ -304,6 +314,31 @@ class SpamChecker: return True + async def user_may_send_3pid_invite( + self, inviter_userid: str, medium: str, address: str, room_id: str + ) -> bool: + """Checks if a given user may invite a given threepid into the room + + If this method returns false, the threepid invite will be rejected. + + Note that if the threepid is already associated with a Matrix user ID, Synapse + will call user_may_invite with said user ID instead. + + Args: + inviter_userid: The user ID of the sender of the invitation + medium: The 3PID's medium (e.g. "email") + address: The 3PID's address (e.g. "alice@example.com") + room_id: The room ID + + Returns: + True if the user may send the invite, otherwise False + """ + for callback in self._user_may_send_3pid_invite_callbacks: + if await callback(inviter_userid, medium, address, room_id) is False: + return False + + return True + async def user_may_create_room(self, userid: str) -> bool: """Checks if a given user may create a room diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c05461bf2a..eef337feeb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1299,10 +1299,22 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if invitee: # Note that update_membership with an action of "invite" can raise # a ShadowBanError, but this was done above already. + # We don't check the invite against the spamchecker(s) here (through + # user_may_invite) because we'll do it further down the line anyway (in + # update_membership_locked). _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: + # Check if the spamchecker(s) allow this invite to go through. + if not await self.spam_checker.user_may_send_3pid_invite( + inviter_userid=requester.user.to_string(), + medium=medium, + address=address, + room_id=room_id, + ): + raise SynapseError(403, "Cannot send threepid invite") + stream_id = await self._make_and_store_3pid_invite( requester, id_server, diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index a41ec6a98f..376853fd65 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2531,3 +2531,73 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) + + +class ThreepidInviteTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("thomas", "hackme") + self.tok = self.login("thomas", "hackme") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_threepid_invite_spamcheck(self): + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable(0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + mock = Mock(return_value=make_awaitable(True)) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEquals(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. + mock.return_value = make_awaitable(False) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEquals(channel.code, 403) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() -- cgit 1.5.1 From 86af6b2f0ef92a317900fd4a4f6d3436ff8a011c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 7 Oct 2021 12:20:03 +0100 Subject: Add a comment in _process_received_pdu (#11011) --- changelog.d/11011.misc | 1 + synapse/handlers/federation_event.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/11011.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11011.misc b/changelog.d/11011.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/11011.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 243be46267..0645ce9392 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -894,6 +894,9 @@ class FederationEventHandler: backfilled=backfilled, ) except AuthError as e: + # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it + # for now + logger.exception("Unexpected AuthError from _check_event_auth") raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) await self._run_push_actions_and_persist_event(event, context, backfilled) -- cgit 1.5.1 From 96fe77c2546598449c1d423c125f84c92620b155 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 7 Oct 2021 12:43:25 +0100 Subject: Improve the logging in _auth_and_persist_outliers (#11010) Include the event ids being peristed --- changelog.d/11010.misc | 1 + synapse/handlers/federation_event.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11010.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11010.misc b/changelog.d/11010.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/11010.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 0645ce9392..f640b417b3 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1161,7 +1161,10 @@ class FederationEventHandler: return logger.info( - "Persisting %i of %i remaining events", len(roots), len(event_map) + "Persisting %i of %i remaining outliers: %s", + len(roots), + len(event_map), + shortstr(e.event_id for e in roots), ) await self._auth_and_persist_fetched_events_inner(origin, room_id, roots) -- cgit 1.5.1 From e0bf34dada709776ae00843e47cd811d1cd195c6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 7 Oct 2021 13:26:11 +0100 Subject: Don't alter directory entries for local users when setting a per-room nickname (#11002) Co-authored-by: Patrick Cloke --- changelog.d/11002.bugfix | 1 + synapse/handlers/user_directory.py | 20 +++++++++++++------- tests/handlers/test_user_directory.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 changelog.d/11002.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11002.bugfix b/changelog.d/11002.bugfix new file mode 100644 index 0000000000..cf894a6314 --- /dev/null +++ b/changelog.d/11002.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where local users' per-room nicknames/avatars were visible to anyone who could see you in the user_directory. diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 97f60b5806..b7b1973346 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -203,6 +203,7 @@ class UserDirectoryHandler(StateDeltasHandler): public_value=Membership.JOIN, ) + is_remote = not self.is_mine_id(state_key) if change is MatchChange.now_false: # Need to check if the server left the room entirely, if so # we might need to remove all the users in that room @@ -224,15 +225,20 @@ class UserDirectoryHandler(StateDeltasHandler): else: logger.debug("Server is still in room: %r", room_id) - include_in_dir = not self.is_mine_id( - state_key - ) or await self.store.should_include_local_user_in_dir(state_key) + include_in_dir = ( + is_remote + or await self.store.should_include_local_user_in_dir(state_key) + ) if include_in_dir: if change is MatchChange.no_change: - # Handle any profile changes - await self._handle_profile_change( - state_key, room_id, prev_event_id, event_id - ) + # Handle any profile changes for remote users. + # (For local users we are not forced to scan membership + # events; instead the rest of the application calls + # `handle_local_profile_change`.) + if is_remote: + await self._handle_profile_change( + state_key, room_id, prev_event_id, event_id + ) continue if change is MatchChange.now_true: # The user joined diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 03fd5a3e2c..47217f0542 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -402,6 +402,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public3 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) self.assertEqual(set(public3), {(alice, room2), (bob, room2)}) + def test_per_room_profile_doesnt_alter_directory_entry(self) -> None: + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + + # Alice should have a user directory entry created at registration. + users = self.get_success(self.user_dir_helper.get_profiles_in_user_directory()) + self.assertEqual( + users[alice], ProfileInfo(display_name="alice", avatar_url=None) + ) + + # Alice makes a room for herself. + room = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + + # Alice sets a nickname unique to that room. + self.helper.send_state( + room, + "m.room.member", + { + "displayname": "Freddy Mercury", + "membership": "join", + }, + alice_token, + state_key=alice, + ) + + # Alice's display name remains the same in the user directory. + search_result = self.get_success(self.handler.search_users(bob, alice, 10)) + self.assertEqual( + search_result["results"], + [{"display_name": "alice", "avatar_url": None, "user_id": alice}], + 0, + ) + def test_private_room(self) -> None: """ A user can be searched for only by people that are either in a public -- cgit 1.5.1 From eb9ddc8c2e807e691fd1820f88f7c0bf43822661 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 8 Oct 2021 07:44:43 -0400 Subject: Remove the deprecated BaseHandler. (#11005) The shared ratelimit function was replaced with a dedicated RequestRatelimiter class (accessible from the HomeServer object). Other properties were copied to each sub-class that inherited from BaseHandler. --- changelog.d/11005.misc | 1 + synapse/api/ratelimiting.py | 86 +++++++++++++++++++++++ synapse/handlers/_base.py | 120 --------------------------------- synapse/handlers/admin.py | 7 +- synapse/handlers/auth.py | 8 +-- synapse/handlers/deactivate_account.py | 6 +- synapse/handlers/device.py | 10 +-- synapse/handlers/directory.py | 9 ++- synapse/handlers/events.py | 12 ++-- synapse/handlers/federation.py | 6 +- synapse/handlers/identity.py | 7 +- synapse/handlers/initial_sync.py | 8 +-- synapse/handlers/message.py | 7 +- synapse/handlers/profile.py | 11 +-- synapse/handlers/read_marker.py | 5 +- synapse/handlers/receipts.py | 6 +- synapse/handlers/register.py | 9 ++- synapse/handlers/room.py | 15 +++-- synapse/handlers/room_list.py | 7 +- synapse/handlers/room_member.py | 8 +-- synapse/handlers/saml.py | 7 +- synapse/handlers/search.py | 9 +-- synapse/handlers/set_password.py | 6 +- synapse/server.py | 11 ++- 24 files changed, 166 insertions(+), 215 deletions(-) create mode 100644 changelog.d/11005.misc delete mode 100644 synapse/handlers/_base.py (limited to 'synapse/handlers') diff --git a/changelog.d/11005.misc b/changelog.d/11005.misc new file mode 100644 index 0000000000..a893591971 --- /dev/null +++ b/changelog.d/11005.misc @@ -0,0 +1 @@ +Remove the deprecated `BaseHandler` object. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index cbdd74025b..e8964097d3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.config.ratelimiting import RateLimitConfig from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -233,3 +234,88 @@ class Ratelimiter: raise LimitExceededError( retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) + + +class RequestRatelimiter: + def __init__( + self, + store: DataStore, + clock: Clock, + rc_message: RateLimitConfig, + rc_admin_redaction: Optional[RateLimitConfig], + ): + self.store = store + self.clock = clock + + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if rc_admin_redaction: + self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( + store=self.store, + clock=self.clock, + rate_hz=rc_admin_redaction.per_second, + burst_count=rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + + async def ratelimit( + self, + requester: Requester, + update: bool = True, + is_admin_redaction: bool = False, + ) -> None: + """Ratelimits requests. + + Args: + requester + update: Whether to record that a request is being processed. + Set to False when doing multiple checks for one request (e.g. + to check up front if we would reject the request), and set to + True for the last call for a given request. + is_admin_redaction: Whether this is a room admin/moderator + redacting an event. If so then we may apply different + ratelimits depending on config. + + Raises: + LimitExceededError if the request should be ratelimited + """ + user_id = requester.user.to_string() + + # The AS user itself is never rate limited. + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service is not None: + return # do not ratelimit app service senders + + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + + # Check if there is a per user override in the DB. + override = await self.store.get_ratelimit_for_user(user_id) + if override: + # If overridden with a null Hz then ratelimiting has been entirely + # disabled for the user + if not override.messages_per_second: + return + + messages_per_second = override.messages_per_second + burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + else: + # Override rate and burst count per-user + await self.request_ratelimiter.ratelimit( + requester, + rate_hz=messages_per_second, + burst_count=burst_count, + update=update, + ) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py deleted file mode 100644 index 0ccef884e7..0000000000 --- a/synapse/handlers/_base.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2014 - 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Optional - -from synapse.api.ratelimiting import Ratelimiter -from synapse.types import Requester - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class BaseHandler: - """ - Common base class for the event handlers. - - Deprecated: new code should not use this. Instead, Handler classes should define the - fields they actually need. The utility methods should either be factored out to - standalone helper functions, or to different Handler classes. - """ - - def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() - self.auth = hs.get_auth() - self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() - self.distributor = hs.get_distributor() - self.clock = hs.get_clock() - self.hs = hs - - # The rate_hz and burst_count are overridden on a per-user basis - self.request_ratelimiter = Ratelimiter( - store=self.store, clock=self.clock, rate_hz=0, burst_count=0 - ) - self._rc_message = self.hs.config.ratelimiting.rc_message - - # Check whether ratelimiting room admin message redaction is enabled - # by the presence of rate limits in the config - if self.hs.config.ratelimiting.rc_admin_redaction: - self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( - store=self.store, - clock=self.clock, - rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second, - burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count, - ) - else: - self.admin_redaction_ratelimiter = None - - self.server_name = hs.hostname - - self.event_builder_factory = hs.get_event_builder_factory() - - async def ratelimit( - self, - requester: Requester, - update: bool = True, - is_admin_redaction: bool = False, - ) -> None: - """Ratelimits requests. - - Args: - requester - update: Whether to record that a request is being processed. - Set to False when doing multiple checks for one request (e.g. - to check up front if we would reject the request), and set to - True for the last call for a given request. - is_admin_redaction: Whether this is a room admin/moderator - redacting an event. If so then we may apply different - ratelimits depending on config. - - Raises: - LimitExceededError if the request should be ratelimited - """ - user_id = requester.user.to_string() - - # The AS user itself is never rate limited. - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service is not None: - return # do not ratelimit app service senders - - messages_per_second = self._rc_message.per_second - burst_count = self._rc_message.burst_count - - # Check if there is a per user override in the DB. - override = await self.store.get_ratelimit_for_user(user_id) - if override: - # If overridden with a null Hz then ratelimiting has been entirely - # disabled for the user - if not override.messages_per_second: - return - - messages_per_second = override.messages_per_second - burst_count = override.burst_count - - if is_admin_redaction and self.admin_redaction_ratelimiter: - # If we have separate config for admin redactions, use a separate - # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) - else: - # Override rate and burst count per-user - await self.request_ratelimiter.ratelimit( - requester, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index bfa7f2c545..a53cd62d3c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -21,18 +21,15 @@ from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class AdminHandler(BaseHandler): +class AdminHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2d0f3d566c..f4612a5b92 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -52,7 +52,6 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.handlers._base import BaseHandler from synapse.handlers.ui_auth import ( INTERACTIVE_AUTH_CHECKERS, UIAuthSessionDataConstants, @@ -186,12 +185,13 @@ class LoginTokenAttributes: auth_provider_id = attr.ib(type=str) -class AuthHandler(BaseHandler): +class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 12bdca7445..e88c3c27ce 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -19,19 +19,17 @@ from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Requester, UserID, create_requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class DeactivateAccountHandler(BaseHandler): +class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 35334725d7..75e6019760 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -40,8 +40,6 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -50,14 +48,16 @@ logger = logging.getLogger(__name__) MAX_DEVICE_DISPLAY_NAME_LEN = 100 -class DeviceWorkerHandler(BaseHandler): +class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.clock = hs.get_clock() self.hs = hs + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() + self.server_name = hs.hostname @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 9078781d5a..14ed7d9879 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -31,18 +31,16 @@ from synapse.appservice import ApplicationService from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class DirectoryHandler(BaseHandler): +class DirectoryHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.auth = hs.get_auth() + self.hs = hs self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() @@ -51,6 +49,7 @@ class DirectoryHandler(BaseHandler): self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() + self.server_name = hs.hostname self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4b3f037072..1f64534a8a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -25,8 +25,6 @@ from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,11 +32,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class EventStreamHandler(BaseHandler): +class EventStreamHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() self.clock = hs.get_clock() + self.hs = hs self.notifier = hs.get_notifier() self.state = hs.get_state_handler() @@ -138,9 +136,9 @@ class EventStreamHandler(BaseHandler): return chunk -class EventHandler(BaseHandler): +class EventHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 043ca4a224..3e341bd287 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -53,7 +53,6 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers._base import BaseHandler from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -78,15 +77,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class FederationHandler(BaseHandler): +class FederationHandler: """Handles general incoming federation requests Incoming events are *not* handled here, for which see FederationEventHandler. """ def __init__(self, hs: "HomeServer"): - super().__init__(hs) - self.hs = hs self.store = hs.get_datastore() @@ -99,6 +96,7 @@ class FederationHandler(BaseHandler): self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self.event_builder_factory = hs.get_event_builder_factory() self._event_auth_handler = hs.get_event_auth_handler() self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self.config = hs.config diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c881475c25..9c319b5383 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -39,8 +39,6 @@ from synapse.util.stringutils import ( valid_id_server_location, ) -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,10 +47,9 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -class IdentityHandler(BaseHandler): +class IdentityHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9ad39a65d8..d4e4556155 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -31,8 +31,6 @@ from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -40,9 +38,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class InitialSyncHandler(BaseHandler): +class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.state_handler = hs.get_state_handler() self.hs = hs self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ccd7827207..4de9f4b828 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -62,8 +62,6 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.server import HomeServer @@ -433,8 +431,7 @@ class EventCreationHandler: self.send_event = ReplicationSendEventRestServlet.make_client(hs) - # This is only used to get at ratelimit function - self.base_handler = BaseHandler(hs) + self.request_ratelimiter = hs.get_request_ratelimiter() # We arbitrarily limit concurrent event creation for a room to 5. # This is to stop us from diverging history *too* much. @@ -1322,7 +1319,7 @@ class EventCreationHandler: original_event and event.sender != original_event.sender ) - await self.base_handler.ratelimit( + await self.request_ratelimiter.ratelimit( requester, is_admin_redaction=is_admin_redaction ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 2e19706c69..e6c3cf585b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -32,8 +32,6 @@ from synapse.types import ( get_domain_from_id, ) -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -43,7 +41,7 @@ MAX_DISPLAYNAME_LEN = 256 MAX_AVATAR_URL_LEN = 1000 -class ProfileHandler(BaseHandler): +class ProfileHandler: """Handles fetching and updating user profile information. ProfileHandler can be instantiated directly on workers and will @@ -54,7 +52,9 @@ class ProfileHandler(BaseHandler): PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.hs = hs self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -62,6 +62,7 @@ class ProfileHandler(BaseHandler): ) self.user_directory_handler = hs.get_user_directory_handler() + self.request_ratelimiter = hs.get_request_ratelimiter() if hs.config.worker.run_background_tasks: self.clock.looping_call( @@ -346,7 +347,7 @@ class ProfileHandler(BaseHandler): if not self.hs.is_mine(target_user): return - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) # Do not actually update the room state for shadow-banned users. if requester.shadow_banned: diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index bd8160e7ed..58593e570e 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -17,17 +17,14 @@ from typing import TYPE_CHECKING from synapse.util.async_helpers import Linearizer -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class ReadMarkerHandler(BaseHandler): +class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.account_data_handler = hs.get_account_data_handler() diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index f21f33ada2..374e961e3b 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.appservice import ApplicationService -from synapse.handlers._base import BaseHandler from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -26,10 +25,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReceiptsHandler(BaseHandler): +class ReceiptsHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 441af7a848..a0e6a01775 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -41,8 +41,6 @@ from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -85,9 +83,10 @@ class LoginDict(TypedDict): refresh_token: Optional[str] -class RegistrationHandler(BaseHandler): +class RegistrationHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -515,7 +514,7 @@ class RegistrationHandler(BaseHandler): # we don't have a local user in the room to craft up an invite with. requires_invite = await self.store.is_host_joined( room_id, - self.server_name, + self._server_name, ) if requires_invite: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d40dbd761d..7072bca1fc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -76,8 +76,6 @@ from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -88,15 +86,18 @@ id_server_scheme = "https://" FIVE_MINUTES_IN_MS = 5 * 60 * 1000 -class RoomCreationHandler(BaseHandler): +class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.hs = hs self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self._event_auth_handler = hs.get_event_auth_handler() self.config = hs.config + self.request_ratelimiter = hs.get_request_ratelimiter() # Room state based off defined presets self._presets_dict: Dict[str, Dict[str, Any]] = { @@ -162,7 +163,7 @@ class RoomCreationHandler(BaseHandler): Raises: ShadowBanError if the requester is shadow-banned. """ - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) user_id = requester.user.to_string() @@ -665,7 +666,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( "room_version", self.config.server.default_room_version.identifier diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index c3d4199ed1..ba7a14d651 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -36,8 +36,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,9 +47,10 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) -class RoomListHandler(BaseHandler): +class RoomListHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index eef337feeb..74e6c7eca6 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -51,8 +51,6 @@ from synapse.types import ( from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -118,9 +116,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, ) - # This is only used to get at the ratelimit function. It's fine there are - # multiple of these as it doesn't store state. - self.base_handler = BaseHandler(hs) + self.request_ratelimiter = hs.get_request_ratelimiter() @abc.abstractmethod async def _remote_join( @@ -1275,7 +1271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - await self.base_handler.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 2fed9f377a..727d75a50c 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -22,7 +22,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest @@ -51,9 +50,11 @@ class Saml2SessionData: ui_auth_session_id: Optional[str] = None -class SamlHandler(BaseHandler): +class SamlHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6d3333ee00..a3ffa26be8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -26,17 +26,18 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class SearchHandler(BaseHandler): +class SearchHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.state_handler = hs.get_state_handler() + self.clock = hs.get_clock() + self.hs = hs self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index a63fac8283..706ad72761 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -17,19 +17,17 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class SetPasswordHandler(BaseHandler): +class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/server.py b/synapse/server.py index 637eb15b78..0783df41d4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -39,7 +39,7 @@ from twisted.web.resource import IResource from synapse.api.auth import Auth from synapse.api.filtering import Filtering -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig @@ -816,3 +816,12 @@ class HomeServer(metaclass=abc.ABCMeta): def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" return self.config.worker.send_federation + + @cache_in_self + def get_request_ratelimiter(self) -> RequestRatelimiter: + return RequestRatelimiter( + self.get_datastore(), + self.get_clock(), + self.config.ratelimiting.rc_message, + self.config.ratelimiting.rc_admin_redaction, + ) -- cgit 1.5.1 From 670a8d9a1e18159917ca1b4f8e5af48a0b258f5e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 8 Oct 2021 12:52:48 +0100 Subject: Fix overwriting profile when making room public (#11003) This splits apart `handle_new_user` into a function which adds an entry to the `user_directory` and a function which updates the room sharing tables. I plan to continue doing more of this kind of refactoring to clarify the implementation. --- changelog.d/11003.bugfix | 1 + synapse/handlers/user_directory.py | 63 +++++++++++++++++-------------- tests/handlers/test_user_directory.py | 71 ++++++++++++++++++++++++++++++++++- 3 files changed, 104 insertions(+), 31 deletions(-) create mode 100644 changelog.d/11003.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11003.bugfix b/changelog.d/11003.bugfix new file mode 100644 index 0000000000..0786f1b886 --- /dev/null +++ b/changelog.d/11003.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a user's per-room nickname/avatar would overwrite their profile in the user directory when a room was made public. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index b7b1973346..8810f048ba 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -242,18 +242,15 @@ class UserDirectoryHandler(StateDeltasHandler): continue if change is MatchChange.now_true: # The user joined - event = await self.store.get_event(event_id, allow_none=True) - # It isn't expected for this event to not exist, but we - # don't want the entire background process to break. - if event is None: - continue - - profile = ProfileInfo( - avatar_url=event.content.get("avatar_url"), - display_name=event.content.get("displayname"), - ) - - await self._handle_new_user(room_id, state_key, profile) + # This may be the first time we've seen a remote user. If + # so, ensure we have a directory entry for them. (We don't + # need to do this for local users: their directory entry + # is created at the point of registration. + if is_remote: + await self._upsert_directory_entry_for_remote_user( + state_key, event_id + ) + await self._track_user_joined_room(room_id, state_key) else: # The user left await self._handle_remove_user(room_id, state_key) else: @@ -303,7 +300,7 @@ class UserDirectoryHandler(StateDeltasHandler): room_id ) - logger.debug("Change: %r, publicness: %r", publicness, is_public) + logger.debug("Publicness change: %r, is_public: %r", publicness, is_public) if publicness is MatchChange.now_true and not is_public: # If we became world readable but room isn't currently public then @@ -314,42 +311,50 @@ class UserDirectoryHandler(StateDeltasHandler): # ignore the change return - other_users_in_room_with_profiles = ( - await self.store.get_users_in_room_with_profiles(room_id) - ) + users_in_room = await self.store.get_users_in_room(room_id) # Remove every user from the sharing tables for that room. - for user_id in other_users_in_room_with_profiles.keys(): + for user_id in users_in_room: await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. - # NOTE: this is not the most efficient method, as handle_new_user sets + # NOTE: this is not the most efficient method, as _track_user_joined_room sets # up local_user -> other_user and other_user_whos_local -> local_user, # which when ran over an entire room, will result in the same values # being added multiple times. The batching upserts shouldn't make this # too bad, though. - for user_id, profile in other_users_in_room_with_profiles.items(): - await self._handle_new_user(room_id, user_id, profile) + for user_id in users_in_room: + await self._track_user_joined_room(room_id, user_id) - async def _handle_new_user( - self, room_id: str, user_id: str, profile: ProfileInfo + async def _upsert_directory_entry_for_remote_user( + self, user_id: str, event_id: str ) -> None: - """Called when we might need to add user to directory - - Args: - room_id: The room ID that user joined or started being public - user_id + """A remote user has just joined a room. Ensure they have an entry in + the user directory. The caller is responsible for making sure they're + remote. """ + event = await self.store.get_event(event_id, allow_none=True) + # It isn't expected for this event to not exist, but we + # don't want the entire background process to break. + if event is None: + return + logger.debug("Adding new user to dir, %r", user_id) await self.store.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url + user_id, event.content.get("displayname"), event.content.get("avatar_url") ) + async def _track_user_joined_room(self, room_id: str, user_id: str) -> None: + """Someone's just joined a room. Update `users_in_public_rooms` or + `users_who_share_private_rooms` as appropriate. + + The caller is responsible for ensuring that the given user is not excluded + from the user directory. + """ is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) - # Now we update users who share rooms with users. other_users_in_room = await self.store.get_users_in_room(room_id) if is_public: diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 47217f0542..db65253773 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -372,8 +372,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Alice makes two rooms. Bob joins one of them. room1 = self.helper.create_room_as(alice, tok=alice_token) room2 = self.helper.create_room_as(alice, tok=alice_token) - print("room1=", room1) - print("room2=", room2) self.helper.join(room1, bob, tok=bob_token) # The user sharing tables should have been updated. @@ -436,6 +434,75 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): 0, ) + def test_making_room_public_doesnt_alter_directory_entry(self) -> None: + """Per-room names shouldn't go to the directory when the room becomes public. + + This isn't about preventing a leak (the room is now public, so the nickname + is too). It's about preserving the invariant that we only show a user's public + profile in the user directory results. + + I made this a Synapse test case rather than a Complement one because + I think this is (strictly speaking) an implementation choice. Synapse + has chosen to only ever use the public profile when responding to a user + directory search. There's no privacy leak here, because making the room + public discloses the per-room name. + + The spec doesn't mandate anything about _how_ a user + should appear in a /user_directory/search result. Hypothetical example: + suppose Bob searches for Alice. When representing Alice in a search + result, it's reasonable to use any of Alice's nicknames that Bob is + aware of. Heck, maybe we even want to use lots of them in a combined + displayname like `Alice (aka "ali", "ally", "41iC3")`. + """ + + # TODO the same should apply when Alice is a remote user. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + # Alice and Bob are in a private room. + room = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(room, src=alice, targ=bob, tok=alice_token) + self.helper.join(room, user=bob, tok=bob_token) + + # Alice has a nickname unique to that room. + + self.helper.send_state( + room, + "m.room.member", + { + "displayname": "Freddy Mercury", + "membership": "join", + }, + alice_token, + state_key=alice, + ) + + # Check Alice isn't recorded as being in a public room. + public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertNotIn((alice, room), public) + + # One of them makes the room public. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "public"}, + alice_token, + ) + + # Check that Alice is now recorded as being in a public room + public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertIn((alice, room), public) + + # Alice's display name remains the same in the user directory. + search_result = self.get_success(self.handler.search_users(bob, alice, 10)) + self.assertEqual( + search_result["results"], + [{"display_name": "alice", "avatar_url": None, "user_id": alice}], + 0, + ) + def test_private_room(self) -> None: """ A user can be searched for only by people that are either in a public -- cgit 1.5.1 From a7d22c36dbbbdd396aeb8938b57b5fd7edb689f3 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 8 Oct 2021 18:35:00 -0500 Subject: Refactor MSC2716 `/batch_send` endpoint into separate handler functions (#10974) --- changelog.d/10974.misc | 1 + synapse/handlers/room_batch.py | 423 ++++++++++++++++++++++++++++++++++++++ synapse/rest/client/room_batch.py | 339 +++++------------------------- synapse/server.py | 5 + 4 files changed, 485 insertions(+), 283 deletions(-) create mode 100644 changelog.d/10974.misc create mode 100644 synapse/handlers/room_batch.py (limited to 'synapse/handlers') diff --git a/changelog.d/10974.misc b/changelog.d/10974.misc new file mode 100644 index 0000000000..8695b378aa --- /dev/null +++ b/changelog.d/10974.misc @@ -0,0 +1 @@ +Refactor [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` mega function into smaller handler functions. diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py new file mode 100644 index 0000000000..51dd4e7555 --- /dev/null +++ b/synapse/handlers/room_batch.py @@ -0,0 +1,423 @@ +import logging +from typing import TYPE_CHECKING, List, Tuple + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.appservice import ApplicationService +from synapse.http.servlet import assert_params_in_dict +from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RoomBatchHandler: + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.state_store = hs.get_storage().state + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: + """Finds the depth which would sort it after the most-recent + prev_event_id but before the successors of those events. If no + successors are found, we assume it's an historical extremity part of the + current batch and use the same depth of the prev_event_ids. + + Args: + prev_event_ids: List of prev event IDs + + Returns: + Inherited depth + """ + ( + most_recent_prev_event_id, + most_recent_prev_event_depth, + ) = await self.store.get_max_depth_of(prev_event_ids) + + # We want to insert the historical event after the `prev_event` but before the successor event + # + # We inherit depth from the successor event instead of the `prev_event` + # because events returned from `/messages` are first sorted by `topological_ordering` + # which is just the `depth` and then tie-break with `stream_ordering`. + # + # We mark these inserted historical events as "backfilled" which gives them a + # negative `stream_ordering`. If we use the same depth as the `prev_event`, + # then our historical event will tie-break and be sorted before the `prev_event` + # when it should come after. + # + # We want to use the successor event depth so they appear after `prev_event` because + # it has a larger `depth` but before the successor event because the `stream_ordering` + # is negative before the successor event. + successor_event_ids = await self.store.get_successor_events( + [most_recent_prev_event_id] + ) + + # If we can't find any successor events, then it's a forward extremity of + # historical messages and we can just inherit from the previous historical + # event which we can already assume has the correct depth where we want + # to insert into. + if not successor_event_ids: + depth = most_recent_prev_event_depth + else: + ( + _, + oldest_successor_depth, + ) = await self.store.get_min_depth_of(successor_event_ids) + + depth = oldest_successor_depth + + return depth + + def create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ) -> JsonDict: + """Creates an event dict for an "insertion" event with the proper fields + and a random batch ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + The new event dictionary to insert. + """ + + next_batch_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. + + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + + async def get_most_recent_auth_event_ids_from_event_id_list( + self, event_ids: List[str] + ) -> List[str]: + """Find the most recent auth event ids (derived from state events) that + allowed that message to be sent. We will use this as a base + to auth our historical messages against. + + Args: + event_ids: List of event ID's to look at + + Returns: + List of event ID's + """ + + ( + most_recent_prev_event_id, + _, + ) = await self.store.get_max_depth_of(event_ids) + # mapping from (type, state_key) -> state_event_id + prev_state_map = await self.state_store.get_state_ids_for_event( + most_recent_prev_event_id + ) + # List of state event ID's + prev_state_ids = list(prev_state_map.values()) + auth_event_ids = prev_state_ids + + return auth_event_ids + + async def persist_state_events_at_start( + self, + state_events_at_start: List[JsonDict], + room_id: str, + initial_auth_event_ids: List[str], + app_service_requester: Requester, + ) -> List[str]: + """Takes all `state_events_at_start` event dictionaries and creates/persists + them as floating state events which don't resolve into the current room state. + They are floating because they reference a fake prev_event which doesn't connect + to the normal DAG at all. + + Args: + state_events_at_start: + room_id: Room where you want the events persisted in. + initial_auth_event_ids: These will be the auth_events for the first + state event created. Each event created afterwards will be + added to the list of auth events for the next state event + created. + app_service_requester: The requester of an application service. + + Returns: + List of state event ID's we just persisted + """ + assert app_service_requester.app_service + + state_event_ids_at_start = [] + auth_event_ids = initial_auth_event_ids.copy() + for state_event in state_events_at_start: + assert_params_in_dict( + state_event, ["type", "origin_server_ts", "content", "sender"] + ) + + logger.debug( + "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", + state_event, + auth_event_ids, + ) + + event_dict = { + "type": state_event["type"], + "origin_server_ts": state_event["origin_server_ts"], + "content": state_event["content"], + "room_id": room_id, + "sender": state_event["sender"], + "state_key": state_event["state_key"], + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + # Make the state events float off on their own so we don't have a + # bunch of `@mxid joined the room` noise between each batch + fake_prev_event_id = "$" + random_string(43) + + # TODO: This is pretty much the same as some other code to handle inserting state in this file + if event_dict["type"] == EventTypes.Member: + membership = event_dict["content"].get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + await self.create_requester_for_user_id_from_app_service( + state_event["sender"], app_service_requester.app_service + ), + target=UserID.from_string(event_dict["state_key"]), + room_id=room_id, + action=membership, + content=event_dict["content"], + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + else: + # TODO: Add some complement tests that adds state that is not member joins + # and will use this code path. Maybe we only want to support join state events + # and can get rid of this `else`? + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self.create_requester_for_user_id_from_app_service( + state_event["sender"], app_service_requester.app_service + ), + event_dict, + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + event_id = event.event_id + + state_event_ids_at_start.append(event_id) + auth_event_ids.append(event_id) + + return state_event_ids_at_start + + async def persist_historical_events( + self, + events_to_create: List[JsonDict], + room_id: str, + initial_prev_event_ids: List[str], + inherited_depth: int, + auth_event_ids: List[str], + app_service_requester: Requester, + ) -> List[str]: + """Create and persists all events provided sequentially. Handles the + complexity of creating events in chronological order so they can + reference each other by prev_event but still persists in + reverse-chronoloical order so they have the correct + (topological_ordering, stream_ordering) and sort correctly from + /messages. + + Args: + events_to_create: List of historical events to create in JSON + dictionary format. + room_id: Room where you want the events persisted in. + initial_prev_event_ids: These will be the prev_events for the first + event created. Each event created afterwards will point to the + previous event created. + inherited_depth: The depth to create the events at (you will + probably by calling inherit_depth_from_prev_ids(...)). + auth_event_ids: Define which events allow you to create the given + event in the room. + app_service_requester: The requester of an application service. + + Returns: + List of persisted event IDs + """ + assert app_service_requester.app_service + + prev_event_ids = initial_prev_event_ids.copy() + + event_ids = [] + events_to_persist = [] + for ev in events_to_create: + assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + + event_dict = { + "type": ev["type"], + "origin_server_ts": ev["origin_server_ts"], + "content": ev["content"], + "room_id": room_id, + "sender": ev["sender"], # requester.user.to_string(), + "prev_events": prev_event_ids.copy(), + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + event, context = await self.event_creation_handler.create_event( + await self.create_requester_for_user_id_from_app_service( + ev["sender"], app_service_requester.app_service + ), + event_dict, + prev_event_ids=event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + logger.debug( + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + event, + prev_event_ids, + auth_event_ids, + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) + + events_to_persist.append((event, context)) + event_id = event.event_id + + event_ids.append(event_id) + prev_event_ids = [event_id] + + # Persist events in reverse-chronological order so they have the + # correct stream_ordering as they are backfilled (which decrements). + # Events are sorted by (topological_ordering, stream_ordering) + # where topological_ordering is just depth. + for (event, context) in reversed(events_to_persist): + await self.event_creation_handler.handle_new_client_event( + await self.create_requester_for_user_id_from_app_service( + event["sender"], app_service_requester.app_service + ), + event=event, + context=context, + ) + + return event_ids + + async def handle_batch_of_events( + self, + events_to_create: List[JsonDict], + room_id: str, + batch_id_to_connect_to: str, + initial_prev_event_ids: List[str], + inherited_depth: int, + auth_event_ids: List[str], + app_service_requester: Requester, + ) -> Tuple[List[str], str]: + """ + Handles creating and persisting all of the historical events as well + as insertion and batch meta events to make the batch navigable in the DAG. + + Args: + events_to_create: List of historical events to create in JSON + dictionary format. + room_id: Room where you want the events created in. + batch_id_to_connect_to: The batch_id from the insertion event you + want this batch to connect to. + initial_prev_event_ids: These will be the prev_events for the first + event created. Each event created afterwards will point to the + previous event created. + inherited_depth: The depth to create the events at (you will + probably by calling inherit_depth_from_prev_ids(...)). + auth_event_ids: Define which events allow you to create the given + event in the room. + app_service_requester: The requester of an application service. + + Returns: + Tuple containing a list of created events and the next_batch_id + """ + + # Connect this current batch to the insertion event from the previous batch + last_event_in_batch = events_to_create[-1] + batch_event = { + "type": EventTypes.MSC2716_BATCH, + "sender": app_service_requester.user.to_string(), + "room_id": room_id, + "content": { + EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # Since the batch event is put at the end of the batch, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_batch["origin_server_ts"], + } + # Add the batch event to the end of the batch (newest-in-time) + events_to_create.append(batch_event) + + # Add an "insertion" event to the start of each batch (next to the oldest-in-time + # event in the batch) so the next batch can be connected to this one. + insertion_event = self.create_insertion_event_dict( + sender=app_service_requester.user.to_string(), + room_id=room_id, + # Since the insertion event is put at the start of the batch, + # where the oldest-in-time event is, copy the origin_server_ts from + # the first event we're inserting + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) + next_batch_id = insertion_event["content"][ + EventContentFields.MSC2716_NEXT_BATCH_ID + ] + # Prepend the insertion event to the start of the batch (oldest-in-time) + events_to_create = [insertion_event] + events_to_create + + # Create and persist all of the historical events + event_ids = await self.persist_historical_events( + events_to_create=events_to_create, + room_id=room_id, + initial_prev_event_ids=initial_prev_event_ids, + inherited_depth=inherited_depth, + auth_event_ids=auth_event_ids, + app_service_requester=app_service_requester, + ) + + return event_ids, next_batch_id diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 1dffcc3147..38ad4c2447 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -15,13 +15,12 @@ import logging import re from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, List, Tuple +from typing import TYPE_CHECKING, Awaitable, Tuple from twisted.web.server import Request -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.appservice import ApplicationService from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -32,7 +31,7 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict from synapse.util.stringutils import random_string if TYPE_CHECKING: @@ -77,102 +76,12 @@ class RoomBatchSendEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.store = hs.get_datastore() - self.state_store = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self.room_batch_handler = hs.get_room_batch_handler() self.txns = HttpTransactionCache(hs) - async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = await self.store.get_max_depth_of(prev_event_ids) - - # We want to insert the historical event after the `prev_event` but before the successor event - # - # We inherit depth from the successor event instead of the `prev_event` - # because events returned from `/messages` are first sorted by `topological_ordering` - # which is just the `depth` and then tie-break with `stream_ordering`. - # - # We mark these inserted historical events as "backfilled" which gives them a - # negative `stream_ordering`. If we use the same depth as the `prev_event`, - # then our historical event will tie-break and be sorted before the `prev_event` - # when it should come after. - # - # We want to use the successor event depth so they appear after `prev_event` because - # it has a larger `depth` but before the successor event because the `stream_ordering` - # is negative before the successor event. - successor_event_ids = await self.store.get_successor_events( - [most_recent_prev_event_id] - ) - - # If we can't find any successor events, then it's a forward extremity of - # historical messages and we can just inherit from the previous historical - # event which we can already assume has the correct depth where we want - # to insert into. - if not successor_event_ids: - depth = most_recent_prev_event_depth - else: - ( - _, - oldest_successor_depth, - ) = await self.store.get_min_depth_of(successor_event_ids) - - depth = oldest_successor_depth - - return depth - - def _create_insertion_event_dict( - self, sender: str, room_id: str, origin_server_ts: int - ) -> JsonDict: - """Creates an event dict for an "insertion" event with the proper fields - and a random batch ID. - - Args: - sender: The event author MXID - room_id: The room ID that the event belongs to - origin_server_ts: Timestamp when the event was sent - - Returns: - The new event dictionary to insert. - """ - - next_batch_id = random_string(8) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": sender, - "room_id": room_id, - "content": { - EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, - "origin_server_ts": origin_server_ts, - } - - return insertion_event - - async def _create_requester_for_user_id_from_app_service( - self, user_id: str, app_service: ApplicationService - ) -> Requester: - """Creates a new requester for the given user_id - and validates that the app service is allowed to control - the given user. - - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - async def on_POST( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: @@ -200,123 +109,62 @@ class RoomBatchSendEventRestServlet(RestServlet): errcode=Codes.MISSING_PARAM, ) + # Verify the batch_id_from_query corresponds to an actual insertion event + # and have the batch connected. + if batch_id_from_query: + corresponding_insertion_event_id = ( + await self.store.get_insertion_event_by_batch_id( + room_id, batch_id_from_query + ) + ) + if corresponding_insertion_event_id is None: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "No insertion event corresponds to the given ?batch_id", + errcode=Codes.INVALID_PARAM, + ) + # For the event we are inserting next to (`prev_event_ids_from_query`), # find the most recent auth events (derived from state events) that # allowed that message to be sent. We will use that as a base # to auth our historical messages against. - ( - most_recent_prev_event_id, - _, - ) = await self.store.get_max_depth_of(prev_event_ids_from_query) - # mapping from (type, state_key) -> state_event_id - prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id + auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list( + prev_event_ids_from_query ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - - state_event_ids_at_start = [] - for state_event in body["state_events_at_start"]: - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, + # Create and persist all of the state events that float off on their own + # before the batch. These will most likely be all of the invite/member + # state events used to auth the upcoming historical messages. + state_event_ids_at_start = ( + await self.room_batch_handler.persist_state_events_at_start( + state_events_at_start=body["state_events_at_start"], + room_id=room_id, + initial_auth_event_ids=auth_event_ids, + app_service_requester=requester, ) + ) + # Update our ongoing auth event ID list with all of the new state we + # just created + auth_event_ids.extend(state_event_ids_at_start) - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - # Make the state events float off on their own - fake_prev_event_id = "$" + random_string(43) - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - event_dict, - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - event_id = event.event_id - - state_event_ids_at_start.append(event_id) - auth_event_ids.append(event_id) - - events_to_create = body["events"] - - inherited_depth = await self._inherit_depth_from_prev_ids( + inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids( prev_event_ids_from_query ) + events_to_create = body["events"] + # Figure out which batch to connect to. If they passed in # batch_id_from_query let's use it. The batch ID passed in comes # from the batch_id in the "insertion" event from the previous batch. last_event_in_batch = events_to_create[-1] - batch_id_to_connect_to = batch_id_from_query base_insertion_event = None if batch_id_from_query: + batch_id_to_connect_to = batch_id_from_query # All but the first base insertion event should point at a fake # event, which causes the HS to ask for the state at the start of # the batch later. + fake_prev_event_id = "$" + random_string(43) prev_event_ids = [fake_prev_event_id] - - # Verify the batch_id_from_query corresponds to an actual insertion event - # and have the batch connected. - corresponding_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id( - room_id, batch_id_from_query - ) - ) - if corresponding_insertion_event_id is None: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "No insertion event corresponds to the given ?batch_id", - errcode=Codes.INVALID_PARAM, - ) - pass # Otherwise, create an insertion event to act as a starting point. # # We don't always have an insertion event to start hanging more history @@ -327,10 +175,12 @@ class RoomBatchSendEventRestServlet(RestServlet): else: prev_event_ids = prev_event_ids_from_query - base_insertion_event_dict = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_batch["origin_server_ts"], + base_insertion_event_dict = ( + self.room_batch_handler.create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_batch["origin_server_ts"], + ) ) base_insertion_event_dict["prev_events"] = prev_event_ids.copy() @@ -338,7 +188,7 @@ class RoomBatchSendEventRestServlet(RestServlet): base_insertion_event, _, ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( + await self.room_batch_handler.create_requester_for_user_id_from_app_service( base_insertion_event_dict["sender"], requester.app_service, ), @@ -353,92 +203,17 @@ class RoomBatchSendEventRestServlet(RestServlet): EventContentFields.MSC2716_NEXT_BATCH_ID ] - # Connect this current batch to the insertion event from the previous batch - batch_event = { - "type": EventTypes.MSC2716_BATCH, - "sender": requester.user.to_string(), - "room_id": room_id, - "content": { - EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to, - EventContentFields.MSC2716_HISTORICAL: True, - }, - # Since the batch event is put at the end of the batch, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_batch["origin_server_ts"], - } - # Add the batch event to the end of the batch (newest-in-time) - events_to_create.append(batch_event) - - # Add an "insertion" event to the start of each batch (next to the oldest-in-time - # event in the batch) so the next batch can be connected to this one. - insertion_event = self._create_insertion_event_dict( - sender=requester.user.to_string(), + # Create and persist all of the historical events as well as insertion + # and batch meta events to make the batch navigable in the DAG. + event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events( + events_to_create=events_to_create, room_id=room_id, - # Since the insertion event is put at the start of the batch, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], + batch_id_to_connect_to=batch_id_to_connect_to, + initial_prev_event_ids=prev_event_ids, + inherited_depth=inherited_depth, + auth_event_ids=auth_event_ids, + app_service_requester=requester, ) - # Prepend the insertion event to the start of the batch (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - event_ids = [] - events_to_persist = [] - for ev in events_to_create: - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event, context = await self.event_creation_handler.create_event( - await self._create_requester_for_user_id_from_app_service( - ev["sender"], requester.app_service - ), - event_dict, - prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", - event, - prev_event_ids, - auth_event_ids, - ) - - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): - ev = await self.event_creation_handler.handle_new_client_event( - await self._create_requester_for_user_id_from_app_service( - event["sender"], requester.app_service - ), - event=event, - context=context, - ) insertion_event_id = event_ids[0] batch_event_id = event_ids[-1] @@ -447,9 +222,7 @@ class RoomBatchSendEventRestServlet(RestServlet): response_dict = { "state_event_ids": state_event_ids_at_start, "event_ids": historical_event_ids, - "next_batch_id": insertion_event["content"][ - EventContentFields.MSC2716_NEXT_BATCH_ID - ], + "next_batch_id": next_batch_id, "insertion_event_id": insertion_event_id, "batch_event_id": batch_event_id, } diff --git a/synapse/server.py b/synapse/server.py index 0783df41d4..5bc045d615 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -97,6 +97,7 @@ from synapse.handlers.room import ( RoomCreationHandler, RoomShutdownHandler, ) +from synapse.handlers.room_batch import RoomBatchHandler from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler @@ -437,6 +438,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_room_creation_handler(self) -> RoomCreationHandler: return RoomCreationHandler(self) + @cache_in_self + def get_room_batch_handler(self) -> RoomBatchHandler: + return RoomBatchHandler(self) + @cache_in_self def get_room_shutdown_handler(self) -> RoomShutdownHandler: return RoomShutdownHandler(self) -- cgit 1.5.1