From a19d01c3d95f5dbd3a4bb181cb70dacd44135a8b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Nov 2021 08:10:58 -0500 Subject: Support filtering by relations per MSC3440 (#11236) Adds experimental support for `relation_types` and `relation_senders` fields for filters. --- synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 8 +++++--- synapse/handlers/search.py | 8 +++++--- synapse/handlers/sync.py | 18 ++++++++++-------- 4 files changed, 21 insertions(+), 15 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index abfe7be0e3..aa26911aed 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -424,7 +424,7 @@ class PaginationHandler: if events: if event_filter: - events = event_filter.filter(events) + events = await event_filter.filter(events) events = await filter_events_for_client( self.storage, user_id, events, is_peeking=(member_event_id is None) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 969eb3b9b0..7a95e31cbe 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1158,8 +1158,10 @@ class RoomContextHandler: ) if event_filter: - results["events_before"] = event_filter.filter(results["events_before"]) - results["events_after"] = event_filter.filter(results["events_after"]) + results["events_before"] = await event_filter.filter( + results["events_before"] + ) + results["events_after"] = await event_filter.filter(results["events_after"]) results["events_before"] = await filter_evts(results["events_before"]) results["events_after"] = await filter_evts(results["events_after"]) @@ -1195,7 +1197,7 @@ class RoomContextHandler: state_events = list(state[last_event_id].values()) if event_filter: - state_events = event_filter.filter(state_events) + state_events = await event_filter.filter(state_events) results["state"] = await filter_evts(state_events) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6e4dff8056..ab7eaab2fb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -180,7 +180,7 @@ class SearchHandler: % (set(group_keys) - {"room_id", "sender"},), ) - search_filter = Filter(filter_dict) + search_filter = Filter(self.hs, filter_dict) # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( @@ -242,7 +242,7 @@ class SearchHandler: rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([r["event"] for r in results]) + filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( self.storage, user.to_string(), filtered_events @@ -292,7 +292,9 @@ class SearchHandler: rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([r["event"] for r in results]) + filtered_events = await search_filter.filter( + [r["event"] for r in results] + ) events = await filter_events_for_client( self.storage, user.to_string(), filtered_events diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2c7c6d63a9..891435c14d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -510,7 +510,7 @@ class SyncHandler: log_kv({"limited": limited}) if potential_recents: - recents = sync_config.filter_collection.filter_room_timeline( + recents = await sync_config.filter_collection.filter_room_timeline( potential_recents ) log_kv({"recents_after_sync_filtering": len(recents)}) @@ -575,8 +575,8 @@ class SyncHandler: log_kv({"loaded_recents": len(events)}) - loaded_recents = sync_config.filter_collection.filter_room_timeline( - events + loaded_recents = ( + await sync_config.filter_collection.filter_room_timeline(events) ) log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)}) @@ -1015,7 +1015,7 @@ class SyncHandler: return { (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state( + for e in await sync_config.filter_collection.filter_room_state( list(state.values()) ) if e.type != EventTypes.Aliases # until MSC2261 or alternative solution @@ -1383,7 +1383,7 @@ class SyncHandler: sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data( + account_data_for_user = await sync_config.filter_collection.filter_account_data( [ {"type": account_data_type, "content": content} for account_data_type, content in account_data.items() @@ -1448,7 +1448,7 @@ class SyncHandler: # Deduplicate the presence entries so that there's at most one per user presence = list({p.user_id: p for p in presence}.values()) - presence = sync_config.filter_collection.filter_presence(presence) + presence = await sync_config.filter_collection.filter_presence(presence) sync_result_builder.presence = presence @@ -2021,12 +2021,14 @@ class SyncHandler: ) account_data_events = ( - sync_config.filter_collection.filter_room_account_data( + await sync_config.filter_collection.filter_room_account_data( account_data_events ) ) - ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) + ephemeral = await sync_config.filter_collection.filter_room_ephemeral( + ephemeral + ) if not ( always_include -- cgit 1.5.1 From b6f4d122efb86e3fc44e358cf573dc2caa6ff634 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 9 Nov 2021 13:11:47 +0000 Subject: Allow admins to proactively block rooms (#11228) Co-authored-by: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/11228.feature | 1 + docs/admin_api/rooms.md | 16 +++++++---- synapse/handlers/room.py | 51 ++++++++++++++++++++++++++-------- synapse/rest/admin/rooms.py | 21 +++++++++++--- synapse/storage/databases/main/room.py | 7 ++++- tests/rest/admin/test_room.py | 28 +++++++++++++++++++ 6 files changed, 103 insertions(+), 21 deletions(-) create mode 100644 changelog.d/11228.feature (limited to 'synapse/handlers') diff --git a/changelog.d/11228.feature b/changelog.d/11228.feature new file mode 100644 index 0000000000..33c1756b50 --- /dev/null +++ b/changelog.d/11228.feature @@ -0,0 +1 @@ +Allow the admin [Delete Room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api) to block a room without the need to join it. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index ab6b82a082..41a4961d00 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -396,13 +396,17 @@ The new room will be created with the user specified by the `new_room_user_id` p as room administrator and will contain a message explaining what happened. Users invited to the new room will have power level `-10` by default, and thus be unable to speak. -If `block` is `True` it prevents new joins to the old room. +If `block` is `true`, users will be prevented from joining the old room. +This option can also be used to pre-emptively block a room, even if it's unknown +to this homeserver. In this case, the room will be blocked, and no further action +will be taken. If `block` is `false`, attempting to delete an unknown room is +invalid and will be rejected as a bad request. This API will remove all trace of the old room from your database after removing all local users. If `purge` is `true` (the default), all traces of the old room will be removed from your database after removing all local users. If you do not want this to happen, set `purge` to `false`. -Depending on the amount of history being purged a call to the API may take +Depending on the amount of history being purged, a call to the API may take several minutes or longer. The local server will only have the power to move local user and room aliases to @@ -464,8 +468,9 @@ The following JSON body parameters are available: `new_room_user_id` in the new room. Ideally this will clearly convey why the original room was shut down. Defaults to `Sharing illegal content on this server is not permitted and rooms in violation will be blocked.` -* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing - future attempts to join the room. Defaults to `false`. +* `block` - Optional. If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Rooms can be blocked + even if they're not yet known to the homeserver. Defaults to `false`. * `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. Defaults to `true`. * `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it @@ -483,7 +488,8 @@ The following fields are returned in the JSON response body: * `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. * `local_aliases` - An array of strings representing the local aliases that were migrated from the old room to the new. -* `new_room_id` - A string representing the room ID of the new room. +* `new_room_id` - A string representing the room ID of the new room, or `null` if + no such room was created. ## Undoing room deletions diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7a95e31cbe..11af30eee7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains functions for performing events on rooms.""" - +"""Contains functions for performing actions on rooms.""" import itertools import logging import math @@ -31,6 +30,8 @@ from typing import ( Tuple, ) +from typing_extensions import TypedDict + from synapse.api.constants import ( EventContentFields, EventTypes, @@ -1277,6 +1278,13 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): return self.store.get_room_events_max_id(room_id) +class ShutdownRoomResponse(TypedDict): + kicked_users: List[str] + failed_to_kick_users: List[str] + local_aliases: List[str] + new_room_id: Optional[str] + + class RoomShutdownHandler: DEFAULT_MESSAGE = ( @@ -1302,7 +1310,7 @@ class RoomShutdownHandler: new_room_name: Optional[str] = None, message: Optional[str] = None, block: bool = False, - ) -> dict: + ) -> ShutdownRoomResponse: """ Shuts down a room. Moves all local users and room aliases automatically to a new room if `new_room_user_id` is set. Otherwise local users only @@ -1336,8 +1344,13 @@ class RoomShutdownHandler: Defaults to `Sharing illegal content on this server is not permitted and rooms in violation will be blocked.` block: - If set to `true`, this room will be added to a blocking list, - preventing future attempts to join the room. Defaults to `false`. + If set to `True`, users will be prevented from joining the old + room. This option can also be used to pre-emptively block a room, + even if it's unknown to this homeserver. In this case, the room + will be blocked, and no further action will be taken. If `False`, + attempting to delete an unknown room is invalid. + + Defaults to `False`. Returns: a dict containing the following keys: kicked_users: An array of users (`user_id`) that were kicked. @@ -1346,7 +1359,9 @@ class RoomShutdownHandler: local_aliases: An array of strings representing the local aliases that were migrated from the old room to the new. - new_room_id: A string representing the room ID of the new room. + new_room_id: + A string representing the room ID of the new room, or None if + no such room was created. """ if not new_room_name: @@ -1357,14 +1372,28 @@ class RoomShutdownHandler: if not RoomID.is_valid(room_id): raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) - if not await self.store.get_room(room_id): - raise NotFoundError("Unknown room id %s" % (room_id,)) - - # This will work even if the room is already blocked, but that is - # desirable in case the first attempt at blocking the room failed below. + # Action the block first (even if the room doesn't exist yet) if block: + # This will work even if the room is already blocked, but that is + # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + if not await self.store.get_room(room_id): + if block: + # We allow you to block an unknown room. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + else: + # But if you don't want to preventatively block another room, + # this function can't do anything useful. + raise NotFoundError( + "Cannot shut down room: unknown room id %s" % (room_id,) + ) + if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): raise SynapseError( diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 05c823d9ce..a2f4edebb8 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from urllib import parse as urlparse from synapse.api.constants import EventTypes, JoinRules, Membership @@ -239,9 +239,22 @@ class RoomRestServlet(RestServlet): # Purge room if purge: - await pagination_handler.purge_room(room_id, force=force_purge) - - return 200, ret + try: + await pagination_handler.purge_room(room_id, force=force_purge) + except NotFoundError: + if block: + # We can block unknown rooms with this endpoint, in which case + # a failed purge is expected. + pass + else: + # But otherwise, we expect this purge to have succeeded. + raise + + # Cast safety: cast away the knowledge that this is a TypedDict. + # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 + # for some discussion on why this is necessary. Either way, + # `ret` is an opaque dictionary blob as far as the rest of the app cares. + return 200, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cefc77fa0f..17b398bb69 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1751,7 +1751,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) async def block_room(self, room_id: str, user_id: str) -> None: - """Marks the room as blocked. Can be called multiple times. + """Marks the room as blocked. + + Can be called multiple times (though we'll only track the last user to + block this room). + + Can be called on a room unknown to this homeserver. Args: room_id: Room to block diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 46116644ce..11ec54c82e 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -14,9 +14,12 @@ import json import urllib.parse +from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock +from parameterized import parameterized + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes @@ -281,6 +284,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_blocked(self.room_id, expect=True) self._has_no_members(self.room_id) + @parameterized.expand([(True,), (False,)]) + def test_block_unknown_room(self, purge: bool) -> None: + """ + We can block an unknown room. In this case, the `purge` argument + should be ignored. + """ + room_id = "!unknown:test" + + # The room isn't already in the blocked rooms table + self._is_blocked(room_id, expect=False) + + # Request the room be blocked. + channel = self.make_request( + "DELETE", + f"/_synapse/admin/v1/rooms/{room_id}", + {"block": True, "purge": purge}, + access_token=self.admin_user_tok, + ) + + # The room is now blocked. + self.assertEqual( + HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] + ) + self._is_blocked(room_id) + def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not yet accepted the privacy policy. This used to fail when we tried to -- cgit 1.5.1 From a026695083d66f9fc67a892b59eee66f039e2038 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 9 Nov 2021 14:31:15 +0000 Subject: Clarifications and small fixes to to-device related code (#11247) Co-authored-by: Patrick Cloke --- changelog.d/11247.misc | 1 + synapse/handlers/appservice.py | 24 +++++++++++++++++---- synapse/handlers/devicemessage.py | 31 +++++++++++++++++++++++---- synapse/storage/databases/main/appservice.py | 8 +++---- synapse/storage/databases/main/deviceinbox.py | 23 +++++++++++++++++--- tests/handlers/test_appservice.py | 8 +++++-- 6 files changed, 78 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11247.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11247.misc b/changelog.d/11247.misc new file mode 100644 index 0000000000..5ce701560e --- /dev/null +++ b/changelog.d/11247.misc @@ -0,0 +1 @@ +Clean up code relating to to-device messages and sending ephemeral events to application services. \ No newline at end of file diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index ddc9105ee9..9abdad262b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -188,7 +188,7 @@ class ApplicationServicesHandler: self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Optional[Collection[Union[str, UserID]]] = None, + users: Collection[Union[str, UserID]], ) -> None: """ This is called by the notifier in the background when an ephemeral event is handled @@ -203,7 +203,9 @@ class ApplicationServicesHandler: value for `stream_key` will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into - them. + receiving them by setting `push_ephemeral` to true in their registration + file. Note that while MSC2409 is experimental, this option is called + `de.sorunome.msc2409.push_ephemeral`. Appservices will only receive ephemeral events that fall within their registered user and room namespaces. @@ -214,6 +216,7 @@ class ApplicationServicesHandler: if not self.notify_appservices: return + # Ignore any unsupported streams if stream_key not in ("typing_key", "receipt_key", "presence_key"): return @@ -230,18 +233,25 @@ class ApplicationServicesHandler: # Additional context: https://github.com/matrix-org/synapse/pull/11137 assert isinstance(new_token, int) + # Check whether there are any appservices which have registered to receive + # ephemeral events. + # + # Note that whether these events are actually relevant to these appservices + # is decided later on. services = [ service for service in self.store.get_app_services() if service.supports_ephemeral ] if not services: + # Bail out early if none of the target appservices have explicitly registered + # to receive these ephemeral events. return # We only start a new background process if necessary rather than # optimistically (to cut down on overhead). self._notify_interested_services_ephemeral( - services, stream_key, new_token, users or [] + services, stream_key, new_token, users ) @wrap_as_background_process("notify_interested_services_ephemeral") @@ -252,7 +262,7 @@ class ApplicationServicesHandler: new_token: int, users: Collection[Union[str, UserID]], ) -> None: - logger.debug("Checking interested services for %s" % (stream_key)) + logger.debug("Checking interested services for %s", stream_key) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: if stream_key == "typing_key": @@ -345,6 +355,9 @@ class ApplicationServicesHandler: Args: service: The application service to check for which events it should receive. + new_token: A receipts event stream token. Purely used to double-check that the + from_token we pull from the database isn't greater than or equal to this + token. Prevents accidentally duplicating work. Returns: A list of JSON dictionaries containing data derived from the read receipts that @@ -382,6 +395,9 @@ class ApplicationServicesHandler: Args: service: The application service that ephemeral events are being sent to. users: The users that should receive the presence update. + new_token: A presence update stream token. Purely used to double-check that the + from_token we pull from the database isn't greater than or equal to this + token. Prevents accidentally duplicating work. Returns: A list of json dictionaries containing data derived from the presence events diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index b6a2a34ab7..b582266af9 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -89,6 +89,13 @@ class DeviceMessageHandler: ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: + """ + Handle receiving to-device messages from remote homeservers. + + Args: + origin: The remote homeserver. + content: The JSON dictionary containing the to-device messages. + """ local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): @@ -135,12 +142,16 @@ class DeviceMessageHandler: message_type, sender_user_id, by_device ) - stream_id = await self.store.add_messages_from_remote_to_device_inbox( + # Add messages to the database. + # Retrieve the stream id of the last-processed to-device message. + last_stream_id = await self.store.add_messages_from_remote_to_device_inbox( origin, message_id, local_messages ) + # Notify listeners that there are new to-device messages to process, + # handing them the latest stream id. self.notifier.on_new_event( - "to_device_key", stream_id, users=local_messages.keys() + "to_device_key", last_stream_id, users=local_messages.keys() ) async def _check_for_unknown_devices( @@ -195,6 +206,14 @@ class DeviceMessageHandler: message_type: str, messages: Dict[str, Dict[str, JsonDict]], ) -> None: + """ + Handle a request from a user to send to-device message(s). + + Args: + requester: The user that is sending the to-device messages. + message_type: The type of to-device messages that are being sent. + messages: A dictionary containing recipients mapped to messages intended for them. + """ sender_user_id = requester.user.to_string() message_id = random_string(16) @@ -257,12 +276,16 @@ class DeviceMessageHandler: "org.matrix.opentracing_context": json_encoder.encode(context), } - stream_id = await self.store.add_messages_to_device_inbox( + # Add messages to the database. + # Retrieve the stream id of the last-processed to-device message. + last_stream_id = await self.store.add_messages_to_device_inbox( local_messages, remote_edu_contents ) + # Notify listeners that there are new to-device messages to process, + # handing them the latest stream id. self.notifier.on_new_event( - "to_device_key", stream_id, users=local_messages.keys() + "to_device_key", last_stream_id, users=local_messages.keys() ) if self.federation_sender: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 2da2659f41..baec35ee27 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore( ) async def set_type_stream_id_for_appservice( - self, service: ApplicationService, type: str, pos: Optional[int] + self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if type not in ("read_receipt", "presence"): + if stream_type not in ("read_receipt", "presence"): raise ValueError( "Expected type to be a valid application stream id type, got %s" - % (type,) + % (stream_type,) ) def set_type_stream_id_for_appservice_txn(txn): - stream_id_type = "%s_stream_id" % type + stream_id_type = "%s_stream_id" % stream_type txn.execute( "UPDATE application_services_state SET %s = ? WHERE as_id=?" % stream_id_type, diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 264e625bd7..ae3afdd5d2 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -134,7 +134,10 @@ class DeviceInboxWorkerStore(SQLBaseStore): limit: The maximum number of messages to retrieve. Returns: - A list of messages for the device and where in the stream the messages got to. + A tuple containing: + * A list of messages for the device. + * The max stream token of these messages. There may be more to retrieve + if the given limit was reached. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id @@ -153,12 +156,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute( sql, (user_id, device_id, last_stream_id, current_stream_id, limit) ) + messages = [] + stream_pos = current_stream_id + for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) + + # If the limit was not reached we know that there's no more data for this + # user/device pair up to current_stream_id. if len(messages) < limit: stream_pos = current_stream_id + return messages, stream_pos return await self.db_pool.runInteraction( @@ -260,13 +270,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) + messages = [] + stream_pos = current_stream_id + for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) + + # If the limit was not reached we know that there's no more data for this + # user/device pair up to current_stream_id. if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id + return messages, stream_pos return await self.db_pool.runInteraction( @@ -372,8 +389,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): """Used to send messages from this server. Args: - local_messages_by_user_and_device: - Dictionary of user_id to device_id to message. + local_messages_by_user_then_device: + Dictionary of recipient user_id to recipient device_id to message. remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 1f6a924452..d6f14e2dba 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -272,7 +272,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): make_awaitable(([event], None)) ) - self.handler.notify_interested_services_ephemeral("receipt_key", 580) + self.handler.notify_interested_services_ephemeral( + "receipt_key", 580, ["@fakerecipient:example.com"] + ) self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( interested_service, [event] ) @@ -300,7 +302,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): make_awaitable(([event], None)) ) - self.handler.notify_interested_services_ephemeral("receipt_key", 579) + self.handler.notify_interested_services_ephemeral( + "receipt_key", 580, ["@fakerecipient:example.com"] + ) self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() def _mkservice(self, is_interested, protocols=None): -- cgit 1.5.1 From 73cbb284b932c8b769e8a42ccda0f239a0fb1a71 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 10 Nov 2021 14:16:06 +0000 Subject: Remove redundant parameters on `_check_event_auth` (#11292) as of #11012, these parameters are unused. --- changelog.d/11292.misc | 1 + synapse/handlers/federation_event.py | 10 ---------- tests/test_federation.py | 2 -- 3 files changed, 1 insertion(+), 12 deletions(-) create mode 100644 changelog.d/11292.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11292.misc b/changelog.d/11292.misc new file mode 100644 index 0000000000..d1b76b1574 --- /dev/null +++ b/changelog.d/11292.misc @@ -0,0 +1 @@ +Remove unused parameters on `FederationEventHandler._check_event_auth`. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 1a1cd93b1a..9917613298 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -981,8 +981,6 @@ class FederationEventHandler: origin, event, context, - state=state, - backfilled=backfilled, ) except AuthError as e: # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it @@ -1332,8 +1330,6 @@ class FederationEventHandler: origin: str, event: EventBase, context: EventContext, - state: Optional[Iterable[EventBase]] = None, - backfilled: bool = False, ) -> EventContext: """ Checks whether an event should be rejected (for failing auth checks). @@ -1344,12 +1340,6 @@ class FederationEventHandler: context: The event context. - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - backfilled: True if the event was backfilled. - Returns: The updated context object. diff --git a/tests/test_federation.py b/tests/test_federation.py index 24fc77d7a7..3eef1c4c05 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -81,8 +81,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase): origin, event, context, - state=None, - backfilled=False, ): return context -- cgit 1.5.1 From 5cace20bf1f3c50ea50d56a938c4eee5825f4b84 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 10 Nov 2021 15:06:54 -0500 Subject: Add missing type hints to `synapse.app`. (#11287) --- changelog.d/11287.misc | 1 + mypy.ini | 19 +---- synapse/app/__init__.py | 11 +-- synapse/app/_base.py | 140 +++++++++++++++++++++++------------- synapse/app/admin_cmd.py | 42 ++++++----- synapse/app/generic_worker.py | 34 +++++---- synapse/app/homeserver.py | 92 ++++++++++-------------- synapse/app/phone_stats_home.py | 23 +++--- synapse/handlers/admin.py | 4 +- synapse/logging/handlers.py | 4 +- synapse/module_api/__init__.py | 4 +- synapse/replication/tcp/resource.py | 4 +- synapse/server.py | 15 ++-- synapse/types.py | 2 + synapse/util/async_helpers.py | 5 +- synapse/util/httpresourcetree.py | 6 +- synapse/util/manhole.py | 7 +- 17 files changed, 223 insertions(+), 190 deletions(-) create mode 100644 changelog.d/11287.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11287.misc b/changelog.d/11287.misc new file mode 100644 index 0000000000..26ec3cb657 --- /dev/null +++ b/changelog.d/11287.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.app`. diff --git a/mypy.ini b/mypy.ini index 04e8eab72b..989116a1a8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,22 +23,6 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( - |synapse/app/__init__.py - |synapse/app/_base.py - |synapse/app/admin_cmd.py - |synapse/app/appservice.py - |synapse/app/client_reader.py - |synapse/app/event_creator.py - |synapse/app/federation_reader.py - |synapse/app/federation_sender.py - |synapse/app/frontend_proxy.py - |synapse/app/generic_worker.py - |synapse/app/homeserver.py - |synapse/app/media_repository.py - |synapse/app/phone_stats_home.py - |synapse/app/pusher.py - |synapse/app/synchrotron.py - |synapse/app/user_dir.py |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/account_data.py @@ -179,6 +163,9 @@ exclude = (?x) [mypy-synapse.api.*] disallow_untyped_defs = True +[mypy-synapse.app.*] +disallow_untyped_defs = True + [mypy-synapse.crypto.*] disallow_untyped_defs = True diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index f9940491e8..ee51480a9e 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import sys +from typing import Container from synapse import python_dependencies # noqa: E402 @@ -27,7 +28,9 @@ except python_dependencies.DependencyException as e: sys.exit(1) -def check_bind_error(e, address, bind_addresses): +def check_bind_error( + e: Exception, address: str, bind_addresses: Container[str] +) -> None: """ This method checks an exception occurred while binding on 0.0.0.0. If :: is specified in the bind addresses a warning is shown. @@ -38,9 +41,9 @@ def check_bind_error(e, address, bind_addresses): When binding on 0.0.0.0 after :: this can safely be ignored. Args: - e (Exception): Exception that was caught. - address (str): Address on which binding was attempted. - bind_addresses (list): Addresses on which the service listens. + e: Exception that was caught. + address: Address on which binding was attempted. + bind_addresses: Addresses on which the service listens. """ if address == "0.0.0.0" and "::" in bind_addresses: logger.warning( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index f2c1028b5d..573bb487b2 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,13 +22,27 @@ import socket import sys import traceback import warnings -from typing import TYPE_CHECKING, Awaitable, Callable, Iterable +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Collection, + Dict, + Iterable, + List, + NoReturn, + Tuple, + cast, +) from cryptography.utils import CryptographyDeprecationWarning -from typing_extensions import NoReturn import twisted -from twisted.internet import defer, error, reactor +from twisted.internet import defer, error, reactor as _reactor +from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP +from twisted.internet.protocol import ServerFactory +from twisted.internet.tcp import Port from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.python.threadpool import ThreadPool @@ -48,6 +62,7 @@ from synapse.logging.context import PreserveLoggingContext from synapse.metrics import register_threadpool from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats +from synapse.types import ISynapseReactor from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process from synapse.util.gai_resolver import GAIResolver @@ -57,33 +72,44 @@ from synapse.util.versionstring import get_version_string if TYPE_CHECKING: from synapse.server import HomeServer +# Twisted injects the global reactor to make it easier to import, this confuses +# mypy which thinks it is a module. Tell it that it a more proper type. +reactor = cast(ISynapseReactor, _reactor) + + logger = logging.getLogger(__name__) # list of tuples of function, args list, kwargs dict -_sighup_callbacks = [] +_sighup_callbacks: List[ + Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] +] = [] -def register_sighup(func, *args, **kwargs): +def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None: """ Register a function to be called when a SIGHUP occurs. Args: - func (function): Function to be called when sent a SIGHUP signal. + func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ _sighup_callbacks.append((func, args, kwargs)) -def start_worker_reactor(appname, config, run_command=reactor.run): +def start_worker_reactor( + appname: str, + config: HomeServerConfig, + run_command: Callable[[], None] = reactor.run, +) -> None: """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor. Pulls configuration from the 'worker' settings in 'config'. Args: - appname (str): application name which will be sent to syslog - config (synapse.config.Config): config object - run_command (Callable[]): callable that actually runs the reactor + appname: application name which will be sent to syslog + config: config object + run_command: callable that actually runs the reactor """ logger = logging.getLogger(config.worker.worker_app) @@ -101,32 +127,32 @@ def start_worker_reactor(appname, config, run_command=reactor.run): def start_reactor( - appname, - soft_file_limit, - gc_thresholds, - pid_file, - daemonize, - print_pidfile, - logger, - run_command=reactor.run, -): + appname: str, + soft_file_limit: int, + gc_thresholds: Tuple[int, int, int], + pid_file: str, + daemonize: bool, + print_pidfile: bool, + logger: logging.Logger, + run_command: Callable[[], None] = reactor.run, +) -> None: """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor Args: - appname (str): application name which will be sent to syslog - soft_file_limit (int): + appname: application name which will be sent to syslog + soft_file_limit: gc_thresholds: - pid_file (str): name of pid file to write to if daemonize is True - daemonize (bool): true to run the reactor in a background process - print_pidfile (bool): whether to print the pid file, if daemonize is True - logger (logging.Logger): logger instance to pass to Daemonize - run_command (Callable[]): callable that actually runs the reactor + pid_file: name of pid file to write to if daemonize is True + daemonize: true to run the reactor in a background process + print_pidfile: whether to print the pid file, if daemonize is True + logger: logger instance to pass to Daemonize + run_command: callable that actually runs the reactor """ - def run(): + def run() -> None: logger.info("Running") setup_jemalloc_stats() change_resource_limit(soft_file_limit) @@ -185,7 +211,7 @@ def redirect_stdio_to_logs() -> None: print("Redirected stdout/stderr to logs") -def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: +def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None: """Register a callback with the reactor, to be called once it is running This can be used to initialise parts of the system which require an asynchronous @@ -195,7 +221,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: will exit. """ - async def wrapper(): + async def wrapper() -> None: try: await cb(*args, **kwargs) except Exception: @@ -224,7 +250,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses, port): +def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: """ Start Prometheus metrics server. """ @@ -236,11 +262,11 @@ def listen_metrics(bind_addresses, port): def listen_manhole( - bind_addresses: Iterable[str], + bind_addresses: Collection[str], port: int, manhole_settings: ManholeConfig, manhole_globals: dict, -): +) -> None: # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so # suppress the warning for now. @@ -259,12 +285,18 @@ def listen_manhole( ) -def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): +def listen_tcp( + bind_addresses: Collection[str], + port: int, + factory: ServerFactory, + reactor: IReactorTCP = reactor, + backlog: int = 50, +) -> List[Port]: """ Create a TCP socket for a port and several addresses Returns: - list[twisted.internet.tcp.Port]: listening for TCP connections + list of twisted.internet.tcp.Port listening for TCP connections """ r = [] for address in bind_addresses: @@ -273,12 +305,19 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) - return r + # IReactorTCP returns an object implementing IListeningPort from listenTCP, + # but we know it will be a Port instance. + return r # type: ignore[return-value] def listen_ssl( - bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50 -): + bind_addresses: Collection[str], + port: int, + factory: ServerFactory, + context_factory: IOpenSSLContextFactory, + reactor: IReactorSSL = reactor, + backlog: int = 50, +) -> List[Port]: """ Create an TLS-over-TCP socket for a port and several addresses @@ -294,10 +333,13 @@ def listen_ssl( except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) - return r + # IReactorSSL incorrectly declares that an int is returned from listenSSL, + # it actually returns an object implementing IListeningPort, but we know it + # will be a Port instance. + return r # type: ignore[return-value] -def refresh_certificate(hs: "HomeServer"): +def refresh_certificate(hs: "HomeServer") -> None: """ Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. @@ -329,7 +371,7 @@ def refresh_certificate(hs: "HomeServer"): logger.info("Context factories updated.") -async def start(hs: "HomeServer"): +async def start(hs: "HomeServer") -> None: """ Start a Synapse server or worker. @@ -360,7 +402,7 @@ async def start(hs: "HomeServer"): if hasattr(signal, "SIGHUP"): @wrap_as_background_process("sighup") - def handle_sighup(*args, **kwargs): + def handle_sighup(*args: Any, **kwargs: Any) -> None: # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. sdnotify(b"RELOADING=1") @@ -373,7 +415,7 @@ async def start(hs: "HomeServer"): # We defer running the sighup handlers until next reactor tick. This # is so that we're in a sane state, e.g. flushing the logs may fail # if the sighup happens in the middle of writing a log entry. - def run_sighup(*args, **kwargs): + def run_sighup(*args: Any, **kwargs: Any) -> None: # `callFromThread` should be "signal safe" as well as thread # safe. reactor.callFromThread(handle_sighup, *args, **kwargs) @@ -436,12 +478,8 @@ async def start(hs: "HomeServer"): atexit.register(gc.freeze) -def setup_sentry(hs: "HomeServer"): - """Enable sentry integration, if enabled in configuration - - Args: - hs - """ +def setup_sentry(hs: "HomeServer") -> None: + """Enable sentry integration, if enabled in configuration""" if not hs.config.metrics.sentry_enabled: return @@ -466,7 +504,7 @@ def setup_sentry(hs: "HomeServer"): scope.set_tag("worker_name", name) -def setup_sdnotify(hs: "HomeServer"): +def setup_sdnotify(hs: "HomeServer") -> None: """Adds process state hooks to tell systemd what we are up to.""" # Tell systemd our state, if we're using it. This will silently fail if @@ -481,7 +519,7 @@ def setup_sdnotify(hs: "HomeServer"): sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET") -def sdnotify(state): +def sdnotify(state: bytes) -> None: """ Send a notification to systemd, if the NOTIFY_SOCKET env var is set. @@ -490,7 +528,7 @@ def sdnotify(state): package which many OSes don't include as a matter of principle. Args: - state (bytes): notification to send + state: notification to send """ if not isinstance(state, bytes): raise TypeError("sdnotify should be called with a bytes") diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index ad20b1d6aa..42238f7f28 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -17,6 +17,7 @@ import logging import os import sys import tempfile +from typing import List, Optional from twisted.internet import defer, task @@ -25,6 +26,7 @@ from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.events import EventBase from synapse.handlers.admin import ExfiltrationWriter from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore @@ -40,6 +42,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.server import HomeServer from synapse.storage.databases.main.room import RoomWorkerStore +from synapse.types import StateMap from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -65,16 +68,11 @@ class AdminCmdSlavedStore( class AdminCmdServer(HomeServer): - DATASTORE_CLASS = AdminCmdSlavedStore + DATASTORE_CLASS = AdminCmdSlavedStore # type: ignore -async def export_data_command(hs: HomeServer, args): - """Export data for a user. - - Args: - hs - args (argparse.Namespace) - """ +async def export_data_command(hs: HomeServer, args: argparse.Namespace) -> None: + """Export data for a user.""" user_id = args.user_id directory = args.output_directory @@ -92,12 +90,12 @@ class FileExfiltrationWriter(ExfiltrationWriter): Note: This writes to disk on the main reactor thread. Args: - user_id (str): The user whose data is being exfiltrated. - directory (str|None): The directory to write the data to, if None then - will write to a temporary directory. + user_id: The user whose data is being exfiltrated. + directory: The directory to write the data to, if None then will write + to a temporary directory. """ - def __init__(self, user_id, directory=None): + def __init__(self, user_id: str, directory: Optional[str] = None): self.user_id = user_id if directory: @@ -111,7 +109,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): if list(os.listdir(self.base_directory)): raise Exception("Directory must be empty") - def write_events(self, room_id, events): + def write_events(self, room_id: str, events: List[EventBase]) -> None: room_directory = os.path.join(self.base_directory, "rooms", room_id) os.makedirs(room_directory, exist_ok=True) events_file = os.path.join(room_directory, "events") @@ -120,7 +118,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in events: print(json.dumps(event.get_pdu_json()), file=f) - def write_state(self, room_id, event_id, state): + def write_state( + self, room_id: str, event_id: str, state: StateMap[EventBase] + ) -> None: room_directory = os.path.join(self.base_directory, "rooms", room_id) state_directory = os.path.join(room_directory, "state") os.makedirs(state_directory, exist_ok=True) @@ -131,7 +131,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event.get_pdu_json()), file=f) - def write_invite(self, room_id, event, state): + def write_invite( + self, room_id: str, event: EventBase, state: StateMap[EventBase] + ) -> None: self.write_events(room_id, [event]) # We write the invite state somewhere else as they aren't full events @@ -145,7 +147,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event), file=f) - def write_knock(self, room_id, event, state): + def write_knock( + self, room_id: str, event: EventBase, state: StateMap[EventBase] + ) -> None: self.write_events(room_id, [event]) # We write the knock state somewhere else as they aren't full events @@ -159,11 +163,11 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event), file=f) - def finished(self): + def finished(self) -> str: return self.base_directory -def start(config_options): +def start(config_options: List[str]) -> None: parser = argparse.ArgumentParser(description="Synapse Admin Command") HomeServerConfig.add_arguments_to_parser(parser) @@ -231,7 +235,7 @@ def start(config_options): # We also make sure that `_base.start` gets run before we actually run the # command. - async def run(): + async def run() -> None: with LoggingContext("command"): await _base.start(ss) await args.func(ss, args) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 218826741e..46f0feff70 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -14,11 +14,10 @@ # limitations under the License. import logging import sys -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple from twisted.internet import address -from twisted.web.resource import IResource -from twisted.web.server import Request +from twisted.web.resource import Resource import synapse import synapse.events @@ -44,7 +43,7 @@ from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource, OptionsResource from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseSite +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -119,6 +118,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore +from synapse.types import JsonDict from synapse.util.httpresourcetree import create_resource_tree from synapse.util.versionstring import get_version_string @@ -143,7 +143,9 @@ class KeyUploadServlet(RestServlet): self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri - async def on_POST(self, request: Request, device_id: Optional[str]): + async def on_POST( + self, request: SynapseRequest, device_id: Optional[str] + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -187,9 +189,8 @@ class KeyUploadServlet(RestServlet): # If the header exists, add to the comma-separated list of the first # instance of the header. Otherwise, generate a new header. if x_forwarded_for: - x_forwarded_for = [ - x_forwarded_for[0] + b", " + previous_host - ] + x_forwarded_for[1:] + x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host] + x_forwarded_for.extend(x_forwarded_for[1:]) else: x_forwarded_for = [previous_host] headers[b"X-Forwarded-For"] = x_forwarded_for @@ -253,13 +254,16 @@ class GenericWorkerSlavedStore( SessionStore, BaseSlavedStore, ): - pass + # Properties that multiple storage classes define. Tell mypy what the + # expected type is. + server_name: str + config: HomeServerConfig class GenericWorkerServer(HomeServer): - DATASTORE_CLASS = GenericWorkerSlavedStore + DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore - def _listen_http(self, listener_config: ListenerConfig): + def _listen_http(self, listener_config: ListenerConfig) -> None: port = listener_config.port bind_addresses = listener_config.bind_addresses @@ -267,10 +271,10 @@ class GenericWorkerServer(HomeServer): site_tag = listener_config.http_options.tag if site_tag is None: - site_tag = port + site_tag = str(port) # We always include a health resource. - resources: Dict[str, IResource] = {"/health": HealthResource()} + resources: Dict[str, Resource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: @@ -386,7 +390,7 @@ class GenericWorkerServer(HomeServer): logger.info("Synapse worker now listening on port %d", port) - def start_listening(self): + def start_listening(self) -> None: for listener in self.config.worker.worker_listeners: if listener.type == "http": self._listen_http(listener) @@ -411,7 +415,7 @@ class GenericWorkerServer(HomeServer): self.get_tcp_replication().start_replication(self) -def start(config_options): +def start(config_options: List[str]) -> None: try: config = HomeServerConfig.load_config("Synapse worker", config_options) except ConfigError as e: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 336c279a44..7bb3744f04 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -16,10 +16,10 @@ import logging import os import sys -from typing import Iterator +from typing import Dict, Iterable, Iterator, List -from twisted.internet import reactor -from twisted.web.resource import EncodingResourceWrapper, IResource +from twisted.internet.tcp import Port +from twisted.web.resource import EncodingResourceWrapper, Resource from twisted.web.server import GzipEncoderFactory from twisted.web.static import File @@ -76,23 +76,27 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.homeserver") -def gz_wrap(r): +def gz_wrap(r: Resource) -> Resource: return EncodingResourceWrapper(r, [GzipEncoderFactory()]) class SynapseHomeServer(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore - def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig): + def _listener_http( + self, config: HomeServerConfig, listener_config: ListenerConfig + ) -> Iterable[Port]: port = listener_config.port bind_addresses = listener_config.bind_addresses tls = listener_config.tls + # Must exist since this is an HTTP listener. + assert listener_config.http_options is not None site_tag = listener_config.http_options.tag if site_tag is None: site_tag = str(port) # We always include a health resource. - resources = {"/health": HealthResource()} + resources: Dict[str, Resource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: @@ -111,7 +115,7 @@ class SynapseHomeServer(HomeServer): ("listeners", site_tag, "additional_resources", "<%s>" % (path,)), ) handler = handler_cls(config, module_api) - if IResource.providedBy(handler): + if isinstance(handler, Resource): resource = handler elif hasattr(handler, "handle_request"): resource = AdditionalResource(self, handler.handle_request) @@ -128,7 +132,7 @@ class SynapseHomeServer(HomeServer): # try to find something useful to redirect '/' to if WEB_CLIENT_PREFIX in resources: - root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) + root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) elif STATIC_PREFIX in resources: root_resource = RootOptionsRedirectResource(STATIC_PREFIX) else: @@ -145,6 +149,8 @@ class SynapseHomeServer(HomeServer): ) if tls: + # refresh_certificate should have been called before this. + assert self.tls_server_context_factory is not None ports = listen_ssl( bind_addresses, port, @@ -165,20 +171,21 @@ class SynapseHomeServer(HomeServer): return ports - def _configure_named_resource(self, name, compress=False): + def _configure_named_resource( + self, name: str, compress: bool = False + ) -> Dict[str, Resource]: """Build a resource map for a named resource Args: - name (str): named resource: one of "client", "federation", etc - compress (bool): whether to enable gzip compression for this - resource + name: named resource: one of "client", "federation", etc + compress: whether to enable gzip compression for this resource Returns: - dict[str, Resource]: map from path to HTTP resource + map from path to HTTP resource """ - resources = {} + resources: Dict[str, Resource] = {} if name == "client": - client_resource = ClientRestResource(self) + client_resource: Resource = ClientRestResource(self) if compress: client_resource = gz_wrap(client_resource) @@ -207,7 +214,7 @@ class SynapseHomeServer(HomeServer): if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource - consent_resource = ConsentResource(self) + consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) resources.update({"/_matrix/consent": consent_resource}) @@ -277,7 +284,7 @@ class SynapseHomeServer(HomeServer): return resources - def start_listening(self): + def start_listening(self) -> None: if self.config.redis.redis_enabled: # If redis is enabled we connect via the replication command handler # in the same way as the workers (since we're effectively a client @@ -303,7 +310,9 @@ class SynapseHomeServer(HomeServer): ReplicationStreamProtocolFactory(self), ) for s in services: - reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) + self.get_reactor().addSystemEventTrigger( + "before", "shutdown", s.stopListening + ) elif listener.type == "metrics": if not self.config.metrics.enable_metrics: logger.warning( @@ -318,14 +327,13 @@ class SynapseHomeServer(HomeServer): logger.warning("Unrecognized listener type: %s", listener.type) -def setup(config_options): +def setup(config_options: List[str]) -> SynapseHomeServer: """ Args: - config_options_options: The options passed to Synapse. Usually - `sys.argv[1:]`. + config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`. Returns: - HomeServer + A homeserver instance. """ try: config = HomeServerConfig.load_or_generate_config( @@ -364,7 +372,7 @@ def setup(config_options): except Exception as e: handle_startup_exception(e) - async def start(): + async def start() -> None: # Load the OIDC provider metadatas, if OIDC is enabled. if hs.config.oidc.oidc_enabled: oidc = hs.get_oidc_handler() @@ -404,39 +412,15 @@ def format_config_error(e: ConfigError) -> Iterator[str]: yield ":\n %s" % (e.msg,) - e = e.__cause__ + parent_e = e.__cause__ indent = 1 - while e: + while parent_e: indent += 1 - yield ":\n%s%s" % (" " * indent, str(e)) - e = e.__cause__ - - -def run(hs: HomeServer): - PROFILE_SYNAPSE = False - if PROFILE_SYNAPSE: - - def profile(func): - from cProfile import Profile - from threading import current_thread - - def profiled(*args, **kargs): - profile = Profile() - profile.enable() - func(*args, **kargs) - profile.disable() - ident = current_thread().ident - profile.dump_stats( - "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident) - ) - - return profiled - - from twisted.python.threadpool import ThreadPool + yield ":\n%s%s" % (" " * indent, str(parent_e)) + parent_e = parent_e.__cause__ - ThreadPool._worker = profile(ThreadPool._worker) - reactor.run = profile(reactor.run) +def run(hs: HomeServer) -> None: _base.start_reactor( "synapse-homeserver", soft_file_limit=hs.config.server.soft_file_limit, @@ -448,7 +432,7 @@ def run(hs: HomeServer): ) -def main(): +def main() -> None: with LoggingContext("main"): # check base requirements check_requirements() diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 126450e17a..899dba5c3d 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,11 +15,12 @@ import logging import math import resource import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Sized, Tuple from prometheus_client import Gauge from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -28,7 +29,7 @@ logger = logging.getLogger("synapse.app.homeserver") # Contains the list of processes we will be monitoring # currently either 0 or 1 -_stats_process = [] +_stats_process: List[Tuple[int, "resource.struct_rusage"]] = [] # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") @@ -45,9 +46,15 @@ registered_reserved_users_mau_gauge = Gauge( @wrap_as_background_process("phone_stats_home") -async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process): +async def phone_stats_home( + hs: "HomeServer", + stats: JsonDict, + stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process, +) -> None: logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) + # Ensure the homeserver has started. + assert hs.start_time is not None uptime = int(now - hs.start_time) if uptime < 0: uptime = 0 @@ -146,15 +153,15 @@ async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process logger.warning("Error reporting stats: %s", e) -def start_phone_stats_home(hs: "HomeServer"): +def start_phone_stats_home(hs: "HomeServer") -> None: """ Start the background tasks which report phone home stats. """ clock = hs.get_clock() - stats = {} + stats: JsonDict = {} - def performance_stats_init(): + def performance_stats_init() -> None: _stats_process.clear() _stats_process.append( (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) @@ -170,10 +177,10 @@ def start_phone_stats_home(hs: "HomeServer"): hs.get_datastore().reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") - async def generate_monthly_active_users(): + async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} - reserved_users = () + reserved_users: Sized = () store = hs.get_datastore() if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index be3203ac80..85157a138b 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -234,7 +234,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): @abc.abstractmethod def write_invite( - self, room_id: str, event: EventBase, state: StateMap[dict] + self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: """Write an invite for the room, with associated invite state. @@ -248,7 +248,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): @abc.abstractmethod def write_knock( - self, room_id: str, event: EventBase, state: StateMap[dict] + self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: """Write a knock for the room, with associated knock state. diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index af5fc407a8..478b527494 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -3,7 +3,7 @@ import time from logging import Handler, LogRecord from logging.handlers import MemoryHandler from threading import Thread -from typing import Optional +from typing import Optional, cast from twisted.internet.interfaces import IReactorCore @@ -56,7 +56,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): if reactor is None: from twisted.internet import reactor as global_reactor - reactor_to_use = global_reactor # type: ignore[assignment] + reactor_to_use = cast(IReactorCore, global_reactor) else: reactor_to_use = reactor diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6e7f5238fe..ff79bc3c11 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -31,7 +31,7 @@ import attr import jinja2 from twisted.internet import defer -from twisted.web.resource import IResource +from twisted.web.resource import Resource from synapse.api.errors import SynapseError from synapse.events import EventBase @@ -196,7 +196,7 @@ class ModuleApi: """ return self._password_auth_provider.register_password_auth_provider_callbacks - def register_web_resource(self, path: str, resource: IResource): + def register_web_resource(self, path: str, resource: Resource): """Registers a web resource to be served at the given path. This function should be called during initialisation of the module. diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 55326877fd..a9d85f4f6c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING from prometheus_client import Counter -from twisted.internet.protocol import Factory +from twisted.internet.protocol import ServerFactory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand @@ -38,7 +38,7 @@ stream_updates_counter = Counter( logger = logging.getLogger(__name__) -class ReplicationStreamProtocolFactory(Factory): +class ReplicationStreamProtocolFactory(ServerFactory): """Factory for new replication connections.""" def __init__(self, hs: "HomeServer"): diff --git a/synapse/server.py b/synapse/server.py index 013a7bacaa..877eba6c08 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -33,9 +33,10 @@ from typing import ( cast, ) -import twisted.internet.tcp +from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.tcp import Port from twisted.web.iweb import IPolicyForHTTPS -from twisted.web.resource import IResource +from twisted.web.resource import Resource from synapse.api.auth import Auth from synapse.api.filtering import Filtering @@ -206,7 +207,7 @@ class HomeServer(metaclass=abc.ABCMeta): Attributes: config (synapse.config.homeserver.HomeserverConfig): - _listening_services (list[twisted.internet.tcp.Port]): TCP ports that + _listening_services (list[Port]): TCP ports that we are listening on to provide HTTP services. """ @@ -225,6 +226,8 @@ class HomeServer(metaclass=abc.ABCMeta): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() + tls_server_context_factory: Optional[IOpenSSLContextFactory] + def __init__( self, hostname: str, @@ -247,7 +250,7 @@ class HomeServer(metaclass=abc.ABCMeta): # the key we use to sign events and requests self.signing_key = config.key.signing_key[0] self.config = config - self._listening_services: List[twisted.internet.tcp.Port] = [] + self._listening_services: List[Port] = [] self.start_time: Optional[int] = None self._instance_id = random_string(5) @@ -257,10 +260,10 @@ class HomeServer(metaclass=abc.ABCMeta): self.datastores: Optional[Databases] = None - self._module_web_resources: Dict[str, IResource] = {} + self._module_web_resources: Dict[str, Resource] = {} self._module_web_resources_consumed = False - def register_module_web_resource(self, path: str, resource: IResource): + def register_module_web_resource(self, path: str, resource: Resource): """Allows a module to register a web resource to be served at the given path. If multiple modules register a resource for the same path, the module that diff --git a/synapse/types.py b/synapse/types.py index 364ecf7d45..9d7a675662 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -38,6 +38,7 @@ from zope.interface import Interface from twisted.internet.interfaces import ( IReactorCore, IReactorPluggableNameResolver, + IReactorSSL, IReactorTCP, IReactorThreads, IReactorTime, @@ -66,6 +67,7 @@ JsonDict = Dict[str, Any] # for mypy-zope to realize it is an interface. class ISynapseReactor( IReactorTCP, + IReactorSSL, IReactorPluggableNameResolver, IReactorTime, IReactorCore, diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 96efc5f3e3..561b962e14 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -31,13 +31,13 @@ from typing import ( Set, TypeVar, Union, + cast, ) import attr from typing_extensions import ContextManager from twisted.internet import defer -from twisted.internet.base import ReactorBase from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime from twisted.python import failure @@ -271,8 +271,7 @@ class Linearizer: if not clock: from twisted.internet import reactor - assert isinstance(reactor, ReactorBase) - clock = Clock(reactor) + clock = Clock(cast(IReactorTime, reactor)) self._clock = clock self.max_count = max_count diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index b163643ca3..a0606851f7 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -92,9 +92,9 @@ def _resource_id(resource: Resource, path_seg: bytes) -> str: the mapping should looks like _resource_id(A,C) = B. Args: - resource (Resource): The *parent* Resourceb - path_seg (str): The name of the child Resource to be attached. + resource: The *parent* Resourceb + path_seg: The name of the child Resource to be attached. Returns: - str: A unique string which can be a key to the child Resource. + A unique string which can be a key to the child Resource. """ return "%s-%r" % (resource, path_seg) diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index f8b2d7bea9..48b8195ca1 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -23,7 +23,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal from twisted.internet import defer -from twisted.internet.protocol import Factory +from twisted.internet.protocol import ServerFactory from synapse.config.server import ManholeConfig @@ -65,7 +65,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" -def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: +def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory: """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with @@ -105,7 +105,8 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment] factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment] - return factory + # ConchFactory is a Factory, not a ServerFactory, but they are identical. + return factory # type: ignore[return-value] class SynapseManhole(ColoredManhole): -- cgit 1.5.1 From 8840a7b7f1074073c49135d13918d9e4d4a04577 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 12 Nov 2021 13:35:31 +0100 Subject: Convert delete room admin API to async endpoint (#11223) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/11223.feature | 1 + docs/admin_api/purge_history_api.md | 2 + docs/admin_api/rooms.md | 189 +++++++++- synapse/handlers/pagination.py | 289 +++++++++++++- synapse/handlers/room.py | 13 +- synapse/rest/admin/__init__.py | 6 + synapse/rest/admin/rooms.py | 134 ++++++- tests/rest/admin/test_admin.py | 48 +++ tests/rest/admin/test_room.py | 726 ++++++++++++++++++++++++++++++++---- 9 files changed, 1317 insertions(+), 91 deletions(-) create mode 100644 changelog.d/11223.feature (limited to 'synapse/handlers') diff --git a/changelog.d/11223.feature b/changelog.d/11223.feature new file mode 100644 index 0000000000..55ea693dcd --- /dev/null +++ b/changelog.d/11223.feature @@ -0,0 +1 @@ +Add a new version of delete room admin API `DELETE /_synapse/admin/v2/rooms/` to run it in background. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/purge_history_api.md b/docs/admin_api/purge_history_api.md index bd29e29ab8..277e28d9cb 100644 --- a/docs/admin_api/purge_history_api.md +++ b/docs/admin_api/purge_history_api.md @@ -70,6 +70,8 @@ This API returns a JSON body like the following: The status will be one of `active`, `complete`, or `failed`. +If `status` is `failed` there will be a string `error` with the error message. + ## Reclaim disk space (Postgres) To reclaim the disk space and return it to the operating system, you need to run diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 41a4961d00..6a6ae92d66 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -4,6 +4,9 @@ - [Room Members API](#room-members-api) - [Room State API](#room-state-api) - [Delete Room API](#delete-room-api) + * [Version 1 (old version)](#version-1-old-version) + * [Version 2 (new version)](#version-2-new-version) + * [Status of deleting rooms](#status-of-deleting-rooms) * [Undoing room shutdowns](#undoing-room-shutdowns) - [Make Room Admin API](#make-room-admin-api) - [Forward Extremities Admin API](#forward-extremities-admin-api) @@ -397,10 +400,10 @@ as room administrator and will contain a message explaining what happened. Users to the new room will have power level `-10` by default, and thus be unable to speak. If `block` is `true`, users will be prevented from joining the old room. -This option can also be used to pre-emptively block a room, even if it's unknown -to this homeserver. In this case, the room will be blocked, and no further action -will be taken. If `block` is `false`, attempting to delete an unknown room is -invalid and will be rejected as a bad request. +This option can in [Version 1](#version-1-old-version) also be used to pre-emptively +block a room, even if it's unknown to this homeserver. In this case, the room will be +blocked, and no further action will be taken. If `block` is `false`, attempting to +delete an unknown room is invalid and will be rejected as a bad request. This API will remove all trace of the old room from your database after removing all local users. If `purge` is `true` (the default), all traces of the old room will @@ -412,6 +415,17 @@ several minutes or longer. The local server will only have the power to move local user and room aliases to the new room. Users on other servers will be unaffected. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see [Admin API](../usage/administration/admin_api). + +## Version 1 (old version) + +This version works synchronously. That means you only get the response once the server has +finished the action, which may take a long time. If you request the same action +a second time, and the server has not finished the first one, the second request will block. +This is fixed in version 2 of this API. The parameters are the same in both APIs. +This API will become deprecated in the future. + The API is: ``` @@ -430,9 +444,6 @@ with a body of: } ``` -To use it, you will need to authenticate by providing an ``access_token`` for a -server admin: see [Admin API](../usage/administration/admin_api). - A response body like the following is returned: ```json @@ -449,6 +460,44 @@ A response body like the following is returned: } ``` +The parameters and response values have the same format as +[version 2](#version-2-new-version) of the API. + +## Version 2 (new version) + +**Note**: This API is new, experimental and "subject to change". + +This version works asynchronously, meaning you get the response from server immediately +while the server works on that task in background. You can then request the status of the action +to check if it has completed. + +The API is: + +``` +DELETE /_synapse/admin/v2/rooms/ +``` + +with a body of: + +```json +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", + "block": true, + "purge": true +} +``` + +The API starts the shut down and purge running, and returns immediately with a JSON body with +a purge id: + +```json +{ + "delete_id": "" +} +``` + **Parameters** The following parameters should be set in the URL: @@ -470,7 +519,8 @@ The following JSON body parameters are available: is not permitted and rooms in violation will be blocked.` * `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to join the room. Rooms can be blocked - even if they're not yet known to the homeserver. Defaults to `false`. + even if they're not yet known to the homeserver (only with + [Version 1](#version-1-old-version) of the API). Defaults to `false`. * `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. Defaults to `true`. * `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it @@ -480,17 +530,124 @@ The following JSON body parameters are available: The JSON body must not be empty. The body must be at least `{}`. -**Response** +## Status of deleting rooms -The following fields are returned in the JSON response body: +**Note**: This API is new, experimental and "subject to change". + +It is possible to query the status of the background task for deleting rooms. +The status can be queried up to 24 hours after completion of the task, +or until Synapse is restarted (whichever happens first). + +### Query by `room_id` -* `kicked_users` - An array of users (`user_id`) that were kicked. -* `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. -* `local_aliases` - An array of strings representing the local aliases that were migrated from - the old room to the new. -* `new_room_id` - A string representing the room ID of the new room, or `null` if - no such room was created. +With this API you can get the status of all active deletion tasks, and all those completed in the last 24h, +for the given `room_id`. + +The API is: + +``` +GET /_synapse/admin/v2/rooms//delete_status +``` + +A response body like the following is returned: + +```json +{ + "results": [ + { + "delete_id": "delete_id1", + "status": "failed", + "error": "error message", + "shutdown_room": { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": null + } + }, { + "delete_id": "delete_id2", + "status": "purging", + "shutdown_room": { + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" + } + } + ] +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +* `room_id` - The ID of the room. + +### Query by `delete_id` + +With this API you can get the status of one specific task by `delete_id`. + +The API is: + +``` +GET /_synapse/admin/v2/rooms/delete_status/ +``` + +A response body like the following is returned: + +```json +{ + "status": "purging", + "shutdown_room": { + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" + } +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +* `delete_id` - The ID for this delete. + +### Response + +The following fields are returned in the JSON response body: +- `results` - An array of objects, each containing information about one task. + This field is omitted from the result when you query by `delete_id`. + Task objects contain the following fields: + - `delete_id` - The ID for this purge if you query by `room_id`. + - `status` - The status will be one of: + - `shutting_down` - The process is removing users from the room. + - `purging` - The process is purging the room and event data from database. + - `complete` - The process has completed successfully. + - `failed` - The process is aborted, an error has occurred. + - `error` - A string that shows an error message if `status` is `failed`. + Otherwise this field is hidden. + - `shutdown_room` - An object containing information about the result of shutting down the room. + *Note:* The result is shown after removing the room members. + The delete process can still be running. Please pay attention to the `status`. + - `kicked_users` - An array of users (`user_id`) that were kicked. + - `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. + - `local_aliases` - An array of strings representing the local aliases that were + migrated from the old room to the new. + - `new_room_id` - A string representing the room ID of the new room, or `null` if + no such room was created. ## Undoing room deletions diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index aa26911aed..cd64142735 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set import attr @@ -22,7 +22,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.logging.context import run_in_background +from synapse.handlers.room import ShutdownRoomResponse from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig @@ -56,11 +56,62 @@ class PurgeStatus: STATUS_FAILED: "failed", } + # Save the error message if an error occurs + error: str = "" + # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}. status: int = STATUS_ACTIVE def asdict(self) -> JsonDict: - return {"status": PurgeStatus.STATUS_TEXT[self.status]} + ret = {"status": PurgeStatus.STATUS_TEXT[self.status]} + if self.error: + ret["error"] = self.error + return ret + + +@attr.s(slots=True, auto_attribs=True) +class DeleteStatus: + """Object tracking the status of a delete room request + + This class contains information on the progress of a delete room request, for + return by get_delete_status. + """ + + STATUS_PURGING = 0 + STATUS_COMPLETE = 1 + STATUS_FAILED = 2 + STATUS_SHUTTING_DOWN = 3 + + STATUS_TEXT = { + STATUS_PURGING: "purging", + STATUS_COMPLETE: "complete", + STATUS_FAILED: "failed", + STATUS_SHUTTING_DOWN: "shutting_down", + } + + # Tracks whether this request has completed. + # One of STATUS_{PURGING,COMPLETE,FAILED,SHUTTING_DOWN}. + status: int = STATUS_PURGING + + # Save the error message if an error occurs + error: str = "" + + # Saves the result of an action to give it back to REST API + shutdown_room: ShutdownRoomResponse = { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + + def asdict(self) -> JsonDict: + ret = { + "status": DeleteStatus.STATUS_TEXT[self.status], + "shutdown_room": self.shutdown_room, + } + if self.error: + ret["error"] = self.error + return ret class PaginationHandler: @@ -70,6 +121,9 @@ class PaginationHandler: paginating during a purge. """ + # when to remove a completed deletion/purge from the results map + CLEAR_PURGE_AFTER_MS = 1000 * 3600 * 24 # 24 hours + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() @@ -78,11 +132,18 @@ class PaginationHandler: self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname + self._room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_lock = ReadWriteLock() + # IDs of rooms in which there currently an active purge *or delete* operation. self._purges_in_progress_by_room: Set[str] = set() # map from purge id to PurgeStatus self._purges_by_id: Dict[str, PurgeStatus] = {} + # map from purge id to DeleteStatus + self._delete_by_id: Dict[str, DeleteStatus] = {} + # map from room id to delete ids + # Dict[`room_id`, List[`delete_id`]] + self._delete_by_room: Dict[str, List[str]] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = ( @@ -265,8 +326,13 @@ class PaginationHandler: logger.info("[purge] starting purge_id %s", purge_id) self._purges_by_id[purge_id] = PurgeStatus() - run_in_background( - self._purge_history, purge_id, room_id, token, delete_local_events + run_as_background_process( + "purge_history", + self._purge_history, + purge_id, + room_id, + token, + delete_local_events, ) return purge_id @@ -276,7 +342,7 @@ class PaginationHandler: """Carry out a history purge on a room. Args: - purge_id: The id for this purge + purge_id: The ID for this purge. room_id: The room to purge from token: topological token to delete events before delete_local_events: True to delete local events as well as remote ones @@ -295,6 +361,7 @@ class PaginationHandler: "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED + self._purges_by_id[purge_id].error = f.getErrorMessage() finally: self._purges_in_progress_by_room.discard(room_id) @@ -302,7 +369,9 @@ class PaginationHandler: def clear_purge() -> None: del self._purges_by_id[purge_id] - self.hs.get_reactor().callLater(24 * 3600, clear_purge) + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_purge + ) def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: """Get the current status of an active purge @@ -312,8 +381,25 @@ class PaginationHandler: """ return self._purges_by_id.get(purge_id) + def get_delete_status(self, delete_id: str) -> Optional[DeleteStatus]: + """Get the current status of an active deleting + + Args: + delete_id: delete_id returned by start_shutdown_and_purge_room + """ + return self._delete_by_id.get(delete_id) + + def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + """Get all active delete ids by room + + Args: + room_id: room_id that is deleted + """ + return self._delete_by_room.get(room_id) + async def purge_room(self, room_id: str, force: bool = False) -> None: """Purge the given room from the database. + This function is part the delete room v1 API. Args: room_id: room to be purged @@ -472,3 +558,192 @@ class PaginationHandler: ) return chunk + + async def _shutdown_and_purge_room( + self, + delete_id: str, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> None: + """ + Shuts down and purges a room. + + See `RoomShutdownHandler.shutdown_room` for details of creation of the new room + + Args: + delete_id: The ID for this delete. + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action. Will be recorded as putting the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from database + also if it fails to remove some users from room. + + Saves a `RoomShutdownHandler.ShutdownRoomResponse` in `DeleteStatus`: + """ + + self._purges_in_progress_by_room.add(room_id) + try: + with await self.pagination_lock.write(room_id): + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN + self._delete_by_id[ + delete_id + ].shutdown_room = await self._room_shutdown_handler.shutdown_room( + room_id=room_id, + requester_user_id=requester_user_id, + new_room_user_id=new_room_user_id, + new_room_name=new_room_name, + message=message, + block=block, + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_PURGING + + if purge: + logger.info("starting purge room_id %s", room_id) + + # first check that we have no users in this room + if not force_purge: + joined = await self.store.is_host_joined( + room_id, self._server_name + ) + if joined: + raise SynapseError( + 400, "Users are still joined to this room" + ) + + await self.storage.purge_events.purge_room(room_id) + + logger.info("complete") + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE + except Exception: + f = Failure() + logger.error( + "failed", + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_FAILED + self._delete_by_id[delete_id].error = f.getErrorMessage() + finally: + self._purges_in_progress_by_room.discard(room_id) + + # remove the delete from the list 24 hours after it completes + def clear_delete() -> None: + del self._delete_by_id[delete_id] + self._delete_by_room[room_id].remove(delete_id) + if not self._delete_by_room[room_id]: + del self._delete_by_room[room_id] + + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_delete + ) + + def start_shutdown_and_purge_room( + self, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> str: + """Start off shut down and purge on a room. + + Args: + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action and put the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from database + also if it fails to remove some users from room. + + Returns: + unique ID for this delete transaction. + """ + if room_id in self._purges_in_progress_by_room: + raise SynapseError( + 400, "History purge already in progress for %s" % (room_id,) + ) + + # This check is double to `RoomShutdownHandler.shutdown_room` + # But here the requester get a direct response / error with HTTP request + # and do not have to check the purge status + if new_room_user_id is not None: + if not self.hs.is_mine_id(new_room_user_id): + raise SynapseError( + 400, "User must be our own: %s" % (new_room_user_id,) + ) + + delete_id = random_string(16) + + # we log the delete_id here so that it can be tied back to the + # request id in the log lines. + logger.info( + "starting shutdown room_id %s with delete_id %s", + room_id, + delete_id, + ) + + self._delete_by_id[delete_id] = DeleteStatus() + self._delete_by_room.setdefault(room_id, []).append(delete_id) + run_as_background_process( + "shutdown_and_purge_room", + self._shutdown_and_purge_room, + delete_id, + room_id, + requester_user_id, + new_room_user_id, + new_room_name, + message, + block, + purge, + force_purge, + ) + return delete_id diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 11af30eee7..f9a099c4f3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1279,6 +1279,17 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): class ShutdownRoomResponse(TypedDict): + """ + Attributes: + kicked_users: An array of users (`user_id`) that were kicked. + failed_to_kick_users: + An array of users (`user_id`) that that were not kicked. + local_aliases: + An array of strings representing the local aliases that were + migrated from the old room to the new. + new_room_id: A string representing the room ID of the new room. + """ + kicked_users: List[str] failed_to_kick_users: List[str] local_aliases: List[str] @@ -1286,7 +1297,6 @@ class ShutdownRoomResponse(TypedDict): class RoomShutdownHandler: - DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" " violation will be blocked." @@ -1299,7 +1309,6 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.state = hs.get_state_handler() self.store = hs.get_datastore() async def shutdown_room( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 81e98f81d6..d78fe406c4 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -46,6 +46,8 @@ from synapse.rest.admin.registration_tokens import ( RegistrationTokenRestServlet, ) from synapse.rest.admin.rooms import ( + DeleteRoomStatusByDeleteIdRestServlet, + DeleteRoomStatusByRoomIdRestServlet, ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, @@ -53,6 +55,7 @@ from synapse.rest.admin.rooms import ( RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, + RoomRestV2Servlet, RoomStateRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -223,7 +226,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRoomRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomRestV2Servlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) + DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server) + DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a2f4edebb8..37cb4d0796 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,7 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -46,6 +46,138 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class RoomRestV2Servlet(RestServlet): + """Delete a room from server asynchronously with a background task. + + It is a combination and improvement of shutdown and purge room. + + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto- + joined to the new room. + + If 'purge' is true, it will remove all traces of a room from the database. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._pagination_handler = hs.get_pagination_handler() + + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) + + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + if not await self._store.get_room(room_id): + raise NotFoundError("Unknown room id %s" % (room_id,)) + + delete_id = self._pagination_handler.start_shutdown_and_purge_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, + purge=purge, + force_purge=force_purge, + ) + + return 200, {"delete_id": delete_id} + + +class DeleteRoomStatusByRoomIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete_status$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) + if delete_ids is None: + raise NotFoundError("No delete task for room_id '%s' found" % room_id) + + response = [] + for delete_id in delete_ids: + delete = self._pagination_handler.get_delete_status(delete_id) + if delete: + response += [ + { + "delete_id": delete_id, + **delete.asdict(), + } + ] + return 200, {"results": cast(JsonDict, response)} + + +class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]+)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, delete_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + delete_status = self._pagination_handler.get_delete_status(delete_id) + if delete_status is None: + raise NotFoundError("delete id '%s' not found" % delete_id) + + return 200, cast(JsonDict, delete_status.asdict()) + + class ListRoomRestServlet(RestServlet): """ List all rooms that are known to the homeserver. Results are returned diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 192073c520..af849bd471 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) + + +class PurgeHistoryTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}" + self.url_status = "/_synapse/admin/v1/purge_history_status/" + + def test_purge_history(self): + """ + Simple test of purge history API. + Test only that is is possible to call, get status 200 and purge_id. + """ + + channel = self.make_request( + "POST", + self.url, + content={"delete_local_events": True, "purge_up_to_ts": 0}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("purge_id", channel.json_body) + purge_id = channel.json_body["purge_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status + purge_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 11ec54c82e..b48fc12e5f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -23,6 +23,7 @@ from parameterized import parameterized import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes +from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room from tests import unittest @@ -71,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): @@ -87,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_room_is_not_valid(self): @@ -103,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -122,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -141,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -160,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): @@ -176,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -202,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -235,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -269,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -344,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -373,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -390,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -446,18 +447,617 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + url = "events?timeout=0&room_id=" + room_id + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + +class DeleteRoomV2TestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.consent.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v2/rooms/{self.room_id}" + self.url_status_by_room_id = ( + f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status" + ) + self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/" + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + method, + url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_room_does_not_exist(self, method: str, url: str): + """ + Check that unknown rooms/server return error 404. + """ + + channel = self.make_request( + method, + url % "!unknown:test", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ] + ) + def test_room_is_not_valid(self, method: str, url: str): + """ + Check that invalid room names, return an error 400. + """ + + channel = self.make_request( + method, + url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@unknown:test"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@not:exist.bla"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + "User must be our own: @not:exist.bla", + channel.json_body["error"], ) + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"purge": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_delete_expired_status(self): + """Test that the task status is removed after expiration.""" + + # first task, do not purge, that we can create a second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id1 = channel.json_body["delete_id"] + + # go ahead + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + # second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id2 = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(2, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual("complete", channel.json_body["results"][1]["status"]) + self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"]) + self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"]) + + # get status after more than clearing time for first task + # second task is not cleared + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) + + # get status after more than clearing time for all tasks + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_delete_same_room_twice(self): + """Test that the call for delete a room at second time gives an exception.""" + + body = {"new_room_user_id": self.admin_user} + + # first call to delete room + # and do not wait for finish the task + first_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + await_result=False, + ) + + # second call to delete room + second_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + ) + + self.assertEqual( + HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body + ) + self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) + self.assertEqual( + f"History purge already in progress for {self.room_id}", + second_channel.json_body["error"], + ) + + # get result of first call + first_channel.await_result() + self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) + self.assertIn("delete_id", first_channel.json_body) + + # check status after finish the task + self._test_result( + first_channel.json_body["delete_id"], + self.other_user, + expect_new_room=True, + ) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": False, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + channel = self.make_request( + "PUT", + url.encode("ascii"), + content={"history_visibility": "world_readable"}, + access_token=self.other_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id: str) -> None: + """Assert there is now no longer anyone in the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id: str, user_id: str) -> None: + """Test that user is member of the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id: str) -> None: + """Test that the following tables have been purged of all rows related to the room.""" + for table in PURGE_TABLES: + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") + + def _assert_peek(self, room_id: str, expect_code: int) -> None: + """Assert that the admin user can (or cannot) peek into the room.""" + + url = f"rooms/{room_id}/initialSync" + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + url = "events?timeout=0&room_id=" + room_id channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + def _test_result( + self, + delete_id: str, + kicked_user: str, + expect_new_room: bool = False, + ) -> None: + """ + Test that the result is the expected. + Uses both APIs (status by room_id and delete_id) + + Args: + delete_id: id of this purge + kicked_user: a user_id which is kicked from the room + expect_new_room: if we expect that a new room was created + """ + + # get information by room_id + channel_room_id = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body + ) + self.assertEqual(1, len(channel_room_id.json_body["results"])) + self.assertEqual( + delete_id, channel_room_id.json_body["results"][0]["delete_id"] ) + # get information by delete_id + channel_delete_id = self.make_request( + "GET", + self.url_status_by_delete_id + delete_id, + access_token=self.admin_user_tok, + ) + self.assertEqual( + HTTPStatus.OK, + channel_delete_id.code, + msg=channel_delete_id.json_body, + ) + + # test values that are the same in both responses + for content in [ + channel_room_id.json_body["results"][0], + channel_delete_id.json_body, + ]: + self.assertEqual("complete", content["status"]) + self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0]) + self.assertIn("failed_to_kick_users", content["shutdown_room"]) + self.assertIn("local_aliases", content["shutdown_room"]) + self.assertNotIn("error", content) + + if expect_new_room: + self.assertIsNotNone(content["shutdown_room"]["new_room_id"]) + else: + self.assertIsNone(content["shutdown_room"]["new_room_id"]) + class RoomTestCase(unittest.HomeserverTestCase): """Test /room admin API.""" @@ -494,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -578,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -620,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_correct_room_attributes(self): """Test the correct attributes for a room are returned""" @@ -643,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -675,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1135,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1185,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.second_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1201,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_local_user_does_not_exist(self): @@ -1217,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_remote_user(self): @@ -1233,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], @@ -1253,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) def test_room_is_not_valid(self): @@ -1270,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], @@ -1289,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1303,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self): @@ -1320,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_join_private_room_if_member(self): @@ -1352,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1363,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1376,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self): @@ -1393,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1407,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self): @@ -1441,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals( - 403, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEquals(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self): @@ -1473,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEquals( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -1532,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -1559,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -1585,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -1623,7 +2219,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", -- cgit 1.5.1 From bea815cec82096efc15e75875d3b8372fcb4f28b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 12 Nov 2021 19:56:00 +0000 Subject: Test room alias deletion (#11327) * Prefer `HTTPStatus` over plain `int` This is an Opinion that no-one has seemed to object to yet. * `--disallow-untyped-defs` for `tests.rest.client.test_directory` * Improve synapse's annotations for deleting aliases * Test case for deleting a room alias * Changelog --- changelog.d/11327.misc | 1 + mypy.ini | 3 + synapse/handlers/directory.py | 6 +- synapse/storage/databases/main/directory.py | 7 +- tests/rest/client/test_directory.py | 105 ++++++++++++++++++++-------- 5 files changed, 91 insertions(+), 31 deletions(-) create mode 100644 changelog.d/11327.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11327.misc b/changelog.d/11327.misc new file mode 100644 index 0000000000..389e360457 --- /dev/null +++ b/changelog.d/11327.misc @@ -0,0 +1 @@ +Test that room alias deletion works as intended. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 56a62bb9b7..67dde5991b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -271,6 +271,9 @@ disallow_untyped_defs = True [mypy-tests] disallow_untyped_defs = True +[mypy-tests.rest.client.test_directory] +disallow_untyped_defs = True + ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. ;; The `typeshed` project maintains stubs here: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8ca5f60b1c..7ee5c47fd9 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -204,6 +204,10 @@ class DirectoryHandler: ) room_id = await self._delete_association(room_alias) + if room_id is None: + # It's possible someone else deleted the association after the + # checks above, but before we did the deletion. + raise NotFoundError("Unknown room alias") try: await self._update_canonical_alias(requester, user_id, room_id, room_alias) @@ -225,7 +229,7 @@ class DirectoryHandler: ) await self._delete_association(room_alias) - async def _delete_association(self, room_alias: RoomAlias) -> str: + async def _delete_association(self, room_alias: RoomAlias) -> Optional[str]: if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 6daf8b8ffb..25131b1ea9 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -17,6 +17,7 @@ from typing import Iterable, List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached @@ -126,14 +127,16 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): - async def delete_room_alias(self, room_alias: RoomAlias) -> str: + async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]: room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: + def _delete_room_alias_txn( + self, txn: LoggingTransaction, room_alias: RoomAlias + ) -> Optional[str]: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index d2181ea907..aca03afd0e 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import directory, login, room +from synapse.server import HomeServer from synapse.types import RoomAlias +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_membership_for_aliases"] = True @@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + """Create two local users and access tokens for them. + One of them creates a room.""" self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_state_event_not_in_room(self): + def test_state_event_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_endpoint_not_in_room(self): + def test_directory_endpoint_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_directory(403) + self.set_alias_via_directory(HTTPStatus.FORBIDDEN) - def test_state_event_in_room_too_long(self): + def test_state_event_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) + self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256) - def test_directory_in_room_too_long(self): + def test_directory_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(400, alias_length=256) + self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256) @override_config({"default_room_version": 5}) - def test_state_event_user_in_v5_room(self): + def test_state_event_user_in_v5_room(self) -> None: """Test that a regular user can add alias events before room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(200) + self.set_alias_via_state_event(HTTPStatus.OK) @override_config({"default_room_version": 6}) - def test_state_event_v6_room(self): + def test_state_event_v6_room(self) -> None: """Test that a regular user can *not* add alias events from room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_in_room(self): + def test_directory_in_room(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(200) + self.set_alias_via_directory(HTTPStatus.OK) - def test_room_creation_too_long(self): + def test_room_creation_too_long(self) -> None: url = "/_matrix/client/r0/createRoom" # We use deliberately a localpart under the length threshold so @@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - def test_room_creation(self): + def test_room_creation(self) -> None: url = "/_matrix/client/r0/createRoom" # Check with an alias of allowed length. There should already be @@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_alias_via_directory(self) -> None: + # Add an alias for the room. We must be joined to do so. + self.ensure_user_joined_room() + alias = self.set_alias_via_directory(HTTPStatus.OK) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_nonexistant_alias(self) -> None: + # Check that no alias exists + alias = "#potato:test" + channel = self.make_request( + "GET", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) - def set_alias_via_state_event(self, expected_code, alias_length=5): + def set_alias_via_state_event( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> None: url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( self.room_id, self.hs.hostname, @@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, expected_code, channel.result) - def set_alias_via_directory(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + def set_alias_via_directory( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> str: + alias = self.random_alias(alias_length) + url = "/_matrix/client/r0/directory/room/%s" % alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase): "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) + return alias - def random_alias(self, length): + def random_alias(self, length: int) -> str: return RoomAlias(random_string(length), self.hs.hostname).to_string() - def ensure_user_left_room(self): + def ensure_user_left_room(self) -> None: self.ensure_membership("leave") - def ensure_user_joined_room(self): + def ensure_user_joined_room(self) -> None: self.ensure_membership("join") - def ensure_membership(self, membership): + def ensure_membership(self, membership: str) -> None: try: if membership == "leave": self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) -- cgit 1.5.1 From 3a1462f7e0693f2fb5f3654a8647552ea46dd67a Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 16 Nov 2021 12:53:31 +0000 Subject: Properly register all callback hooks for legacy password authentication providers (#11340) Co-authored-by: Patrick Cloke --- changelog.d/11340.bugfix | 1 + synapse/handlers/auth.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 13 deletions(-) create mode 100644 changelog.d/11340.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11340.bugfix b/changelog.d/11340.bugfix new file mode 100644 index 0000000000..551817f42d --- /dev/null +++ b/changelog.d/11340.bugfix @@ -0,0 +1 @@ +Fix a bug, introduced in Synapse 1.46.0, which caused the `check_3pid_auth` and `on_logged_out` callbacks in legacy password authentication provider modules to not be registered. Modules using the generic module API were not affected. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 60e59d11a0..b62e13b725 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1828,13 +1828,6 @@ def load_single_legacy_password_auth_provider( logger.error("Error while initializing %r: %s", module, e) raise - # The known hooks. If a module implements a method who's name appears in this set - # we'll want to register it - password_auth_provider_methods = { - "check_3pid_auth", - "on_logged_out", - } - # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: @@ -1919,11 +1912,14 @@ def load_single_legacy_password_auth_provider( return run - # populate hooks with the implemented methods, wrapped with async_wrapper - hooks = { - hook: async_wrapper(getattr(provider, hook, None)) - for hook in password_auth_provider_methods - } + # If the module has these methods implemented, then we pull them out + # and register them as hooks. + check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper( + getattr(provider, "check_3pid_auth", None) + ) + on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper( + getattr(provider, "on_logged_out", None) + ) supported_login_types = {} # call get_supported_login_types and add that to the dict @@ -1950,7 +1946,11 @@ def load_single_legacy_password_auth_provider( # need to use a tuple here for ("password",) not a list since lists aren't hashable auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password - api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) + api.register_password_auth_provider_callbacks( + check_3pid_auth=check_3pid_auth_hook, + on_logged_out=on_logged_out_hook, + auth_checkers=auth_checkers, + ) CHECK_3PID_AUTH_CALLBACK = Callable[ -- cgit 1.5.1 From 88375beeaab3f3902d6f22fea6a14daa8f2b8a5a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 16 Nov 2021 15:40:47 +0000 Subject: Avoid sharing room hierarchy responses between users (#11355) Different users may be allowed to see different rooms within a space, so sharing responses between users is inadvisable. --- changelog.d/11355.bugfix | 1 + synapse/handlers/room_summary.py | 11 ++++++-- tests/handlers/test_room_summary.py | 55 +++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11355.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11355.bugfix b/changelog.d/11355.bugfix new file mode 100644 index 0000000000..91639f14b2 --- /dev/null +++ b/changelog.d/11355.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.41.0 where space hierarchy responses would be incorrectly reused if multiple users were to make the same request at the same time. diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index fb26ee7ad7..8181cc0b52 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -97,7 +97,7 @@ class RoomSummaryHandler: # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. self._pagination_response_cache: ResponseCache[ - Tuple[str, bool, Optional[int], Optional[int], Optional[str]] + Tuple[str, str, bool, Optional[int], Optional[int], Optional[str]] ] = ResponseCache( hs.get_clock(), "get_room_hierarchy", @@ -282,7 +282,14 @@ class RoomSummaryHandler: # This is due to the pagination process mutating internal state, attempting # to process multiple requests for the same page will result in errors. return await self._pagination_response_cache.wrap( - (requested_room_id, suggested_only, max_depth, limit, from_token), + ( + requester, + requested_room_id, + suggested_only, + max_depth, + limit, + from_token, + ), self._get_room_hierarchy, requester, requested_room_id, diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index d3d0bf1ac5..7b95844b55 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -14,6 +14,8 @@ from typing import Any, Iterable, List, Optional, Tuple from unittest import mock +from twisted.internet.defer import ensureDeferred + from synapse.api.constants import ( EventContentFields, EventTypes, @@ -316,6 +318,59 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): AuthError, ) + def test_room_hierarchy_cache(self) -> None: + """In-flight room hierarchy requests are deduplicated.""" + # Run two `get_room_hierarchy` calls up until they block. + deferred1 = ensureDeferred( + self.handler.get_room_hierarchy(self.user, self.space) + ) + deferred2 = ensureDeferred( + self.handler.get_room_hierarchy(self.user, self.space) + ) + + # Complete the two calls. + result1 = self.get_success(deferred1) + result2 = self.get_success(deferred2) + + # Both `get_room_hierarchy` calls should return the same result. + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_hierarchy(result1, expected) + self._assert_hierarchy(result2, expected) + self.assertIs(result1, result2) + + # A subsequent `get_room_hierarchy` call should not reuse the result. + result3 = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result3, expected) + self.assertIsNot(result1, result3) + + def test_room_hierarchy_cache_sharing(self) -> None: + """Room hierarchy responses for different users are not shared.""" + user2 = self.register_user("user2", "pass") + + # Make the room within the space invite-only. + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + + # Run two `get_room_hierarchy` calls for different users up until they block. + deferred1 = ensureDeferred( + self.handler.get_room_hierarchy(self.user, self.space) + ) + deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space)) + + # Complete the two calls. + result1 = self.get_success(deferred1) + result2 = self.get_success(deferred2) + + # The `get_room_hierarchy` calls should return different results. + self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())]) + self._assert_hierarchy(result2, [(self.space, [self.room])]) + def _create_room_with_join_rule( self, join_rule: str, room_version: Optional[str] = None, **extra_content ) -> str: -- cgit 1.5.1 From 0d86f6334ae5dc6be76b5bf2eea9ac5677fb89e8 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 17 Nov 2021 14:10:57 +0000 Subject: Rename `get_access_token_for_user_id` method to `create_access_token_for_user_id` (#11369) --- changelog.d/11369.misc | 1 + synapse/handlers/auth.py | 4 ++-- synapse/handlers/register.py | 2 +- synapse/rest/admin/users.py | 2 +- tests/handlers/test_auth.py | 10 +++++----- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/test_capabilities.py | 8 ++++---- 7 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11369.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11369.misc b/changelog.d/11369.misc new file mode 100644 index 0000000000..3c1dad544b --- /dev/null +++ b/changelog.d/11369.misc @@ -0,0 +1 @@ +Rename `get_access_token_for_user_id` to `create_access_token_for_user_id` to better reflect what it does. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b62e13b725..b2c84a0fce 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -793,7 +793,7 @@ class AuthHandler: ) = await self.get_refresh_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id ) - access_token = await self.get_access_token_for_user_id( + access_token = await self.create_access_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id, valid_until_ms=valid_until_ms, @@ -855,7 +855,7 @@ class AuthHandler: ) return refresh_token, refresh_token_id - async def get_access_token_for_user_id( + async def create_access_token_for_user_id( self, user_id: str, device_id: Optional[str], diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a0e6a01775..6b5c3d1974 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -819,7 +819,7 @@ class RegistrationHandler: ) valid_until_ms = self.clock.time_msec() + self.access_token_lifetime - access_token = await self._auth_handler.get_access_token_for_user_id( + access_token = await self._auth_handler.create_access_token_for_user_id( user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 23a8bf1fdb..ccd9a2a175 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -898,7 +898,7 @@ class UserTokenRestServlet(RestServlet): if auth_user.to_string() == user_id: raise SynapseError(400, "Cannot use admin API to login as self") - token = await self.auth_handler.get_access_token_for_user_id( + token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), device_id=None, valid_until_ms=valid_until_ms, diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 12857053e7..72e176da75 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -116,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) @@ -134,7 +134,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.get_failure( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ), ResourceLimitError, @@ -162,7 +162,7 @@ class AuthTestCase(unittest.HomeserverTestCase): # If not in monthly active cohort self.get_failure( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ), ResourceLimitError, @@ -179,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase): return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) @@ -197,7 +197,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) # Ensure does not raise exception self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index c9fe0f06c2..5011e54563 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.admin_user, device_id=None, valid_until_ms=None ) ) self.other_user = self.register_user("user", "pass", displayname="User") self.other_user_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.other_user, device_id=None, valid_until_ms=None ) ) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index b9e3602552..249808b031 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"localdb_enabled": False}}) def test_get_change_password_capabilities_localdb_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"enabled": False}}) def test_get_change_password_capabilities_password_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"experimental_features": {"msc3244_enabled": False}}) def test_get_does_not_include_msc3244_fields_when_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def test_get_does_include_msc3244_fields_when_enabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) -- cgit 1.5.1 From 84fac0f814f69645ff1ad564ef8294b31203dc95 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 17 Nov 2021 19:07:02 +0000 Subject: Add type annotations to `synapse.metrics` (#10847) --- changelog.d/10847.misc | 1 + mypy.ini | 3 + synapse/app/_base.py | 2 +- synapse/groups/attestations.py | 4 +- synapse/handlers/typing.py | 2 +- synapse/metrics/__init__.py | 101 ++++++++++++++++++-------- synapse/metrics/_exposition.py | 34 ++++----- synapse/metrics/background_process_metrics.py | 78 +++++++++++++++----- synapse/metrics/jemalloc.py | 10 ++- synapse/storage/database.py | 6 +- synapse/util/caches/expiringcache.py | 2 +- synapse/util/metrics.py | 15 ++-- 12 files changed, 173 insertions(+), 85 deletions(-) create mode 100644 changelog.d/10847.misc (limited to 'synapse/handlers') diff --git a/changelog.d/10847.misc b/changelog.d/10847.misc new file mode 100644 index 0000000000..7933a38dca --- /dev/null +++ b/changelog.d/10847.misc @@ -0,0 +1 @@ +Add type annotations to `synapse.metrics`. diff --git a/mypy.ini b/mypy.ini index f32c6c41a3..308cfd95d8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -160,6 +160,9 @@ disallow_untyped_defs = True [mypy-synapse.handlers.*] disallow_untyped_defs = True +[mypy-synapse.metrics.*] +disallow_untyped_defs = True + [mypy-synapse.push.*] disallow_untyped_defs = True diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 573bb487b2..807ee3d46e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -402,7 +402,7 @@ async def start(hs: "HomeServer") -> None: if hasattr(signal, "SIGHUP"): @wrap_as_background_process("sighup") - def handle_sighup(*args: Any, **kwargs: Any) -> None: + async def handle_sighup(*args: Any, **kwargs: Any) -> None: # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. sdnotify(b"RELOADING=1") diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 53f99031b1..a87896e538 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple from signedjson.sign import sign_json +from twisted.internet.defer import Deferred + from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, get_domain_from_id @@ -166,7 +168,7 @@ class GroupAttestionRenewer: return {} - def _start_renew_attestations(self) -> None: + def _start_renew_attestations(self) -> "Deferred[None]": return run_as_background_process("renew_attestations", self._renew_attestations) async def _renew_attestations(self) -> None: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 22c6174821..1676ebd057 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -90,7 +90,7 @@ class FollowerTypingHandler: self.wheel_timer = WheelTimer(bucket_size=5000) @wrap_as_background_process("typing._handle_timeouts") - def _handle_timeouts(self) -> None: + async def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 91ee5c8193..ceef57ad88 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -20,10 +20,25 @@ import os import platform import threading import time -from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) import attr -from prometheus_client import Counter, Gauge, Histogram +from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric from prometheus_client.core import ( REGISTRY, CounterMetricFamily, @@ -32,6 +47,7 @@ from prometheus_client.core import ( ) from twisted.internet import reactor +from twisted.internet.base import ReactorBase from twisted.python.threadpool import ThreadPool import synapse @@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") class RegistryProxy: @staticmethod - def collect(): + def collect() -> Iterable[Metric]: for metric in REGISTRY.collect(): if not metric.name.startswith("__"): yield metric @@ -74,7 +90,7 @@ class LaterGauge: ] ) - def collect(self): + def collect(self) -> Iterable[Metric]: g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) @@ -93,10 +109,10 @@ class LaterGauge: yield g - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._register() - def _register(self): + def _register(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -105,7 +121,12 @@ class LaterGauge: all_gauges[self.name] = self -class InFlightGauge: +# `MetricsEntry` only makes sense when it is a `Protocol`, +# but `Protocol` can't be used as a `TypeVar` bound. +MetricsEntry = TypeVar("MetricsEntry") + + +class InFlightGauge(Generic[MetricsEntry]): """Tracks number of things (e.g. requests, Measure blocks, etc) in flight at any given time. @@ -115,14 +136,19 @@ class InFlightGauge: callbacks. Args: - name (str) - desc (str) - labels (list[str]) - sub_metrics (list[str]): A list of sub metrics that the callbacks - will update. + name + desc + labels + sub_metrics: A list of sub metrics that the callbacks will update. """ - def __init__(self, name, desc, labels, sub_metrics): + def __init__( + self, + name: str, + desc: str, + labels: Sequence[str], + sub_metrics: Sequence[str], + ): self.name = name self.desc = desc self.labels = labels @@ -130,19 +156,25 @@ class InFlightGauge: # Create a class which have the sub_metrics values as attributes, which # default to 0 on initialization. Used to pass to registered callbacks. - self._metrics_class = attr.make_class( + self._metrics_class: Type[MetricsEntry] = attr.make_class( "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True ) # Counts number of in flight blocks for a given set of label values - self._registrations: Dict = {} + self._registrations: Dict[ + Tuple[str, ...], Set[Callable[[MetricsEntry], None]] + ] = {} # Protects access to _registrations self._lock = threading.Lock() self._register_with_collector() - def register(self, key, callback): + def register( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've entered a new block with labels `key`. `callback` gets called each time the metrics are collected. The same @@ -158,13 +190,17 @@ class InFlightGauge: with self._lock: self._registrations.setdefault(key, set()).add(callback) - def unregister(self, key, callback): + def unregister( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've exited a block with labels `key`.""" with self._lock: self._registrations.setdefault(key, set()).discard(callback) - def collect(self): + def collect(self) -> Iterable[Metric]: """Called by prometheus client when it reads metrics. Note: may be called by a separate thread. @@ -200,7 +236,7 @@ class InFlightGauge: gauge.add_metric(key, getattr(metrics, name)) yield gauge - def _register_with_collector(self): + def _register_with_collector(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -230,7 +266,7 @@ class GaugeBucketCollector: name: str, documentation: str, buckets: Iterable[float], - registry=REGISTRY, + registry: CollectorRegistry = REGISTRY, ): """ Args: @@ -257,12 +293,12 @@ class GaugeBucketCollector: registry.register(self) - def collect(self): + def collect(self) -> Iterable[Metric]: # Don't report metrics unless we've already collected some data if self._metric is not None: yield self._metric - def update_data(self, values: Iterable[float]): + def update_data(self, values: Iterable[float]) -> None: """Update the data to be reported by the metric The existing data is cleared, and each measurement in the input is assigned @@ -304,7 +340,7 @@ class GaugeBucketCollector: class CPUMetrics: - def __init__(self): + def __init__(self) -> None: ticks_per_sec = 100 try: # Try and get the system config @@ -314,7 +350,7 @@ class CPUMetrics: self.ticks_per_sec = ticks_per_sec - def collect(self): + def collect(self) -> Iterable[Metric]: if not HAVE_PROC_SELF_STAT: return @@ -364,7 +400,7 @@ gc_time = Histogram( class GCCounts: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): cm.add_metric([str(n)], m) @@ -382,7 +418,7 @@ if not running_on_pypy: class PyPyGCStats: - def collect(self): + def collect(self) -> Iterable[Metric]: # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. @@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None: class ReactorLastSeenMetric: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", @@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0) _last_gc = [0.0, 0.0, 0.0] -def runUntilCurrentTimer(reactor, func): +F = TypeVar("F", bound=Callable[..., Any]) + + +def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F: @functools.wraps(func) - def f(*args, **kwargs): + def f(*args: Any, **kwargs: Any) -> Any: now = reactor.seconds() num_pending = 0 @@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func): return ret - return f + return cast(F, f) try: @@ -677,5 +716,5 @@ __all__ = [ "start_http_server", "LaterGauge", "InFlightGauge", - "BucketCollector", + "GaugeBucketCollector", ] diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index bb9bcb5592..353d0a63b6 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -25,27 +25,25 @@ import math import threading from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn -from typing import Dict, List +from typing import Any, Dict, List, Type, Union from urllib.parse import parse_qs, urlparse -from prometheus_client import REGISTRY +from prometheus_client import REGISTRY, CollectorRegistry +from prometheus_client.core import Sample from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.util import caches CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" -INF = float("inf") -MINUS_INF = float("-inf") - - -def floatToGoString(d): +def floatToGoString(d: Union[int, float]) -> str: d = float(d) - if d == INF: + if d == math.inf: return "+Inf" - elif d == MINUS_INF: + elif d == -math.inf: return "-Inf" elif math.isnan(d): return "NaN" @@ -60,7 +58,7 @@ def floatToGoString(d): return s -def sample_line(line, name): +def sample_line(line: Sample, name: str) -> str: if line.labels: labelstr = "{{{0}}}".format( ",".join( @@ -82,7 +80,7 @@ def sample_line(line, name): return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) -def generate_latest(registry, emit_help=False): +def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: # Trigger the cache metrics to be rescraped, which updates the common # metrics but do not produce metrics themselves @@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler): registry = REGISTRY - def do_GET(self): + def do_GET(self) -> None: registry = self.registry params = parse_qs(urlparse(self.path).query) @@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler): self.end_headers() self.wfile.write(output) - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: """Log nothing.""" @classmethod - def factory(cls, registry): + def factory(cls, registry: CollectorRegistry) -> Type: """Returns a dynamic MetricsHandler class tied to the passed registry. """ @@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): daemon_threads = True -def start_http_server(port, addr="", registry=REGISTRY): +def start_http_server( + port: int, addr: str = "", registry: CollectorRegistry = REGISTRY +) -> None: """Starts an HTTP server for prometheus metrics as a daemon thread""" CustomMetricsHandler = MetricsHandler.factory(registry) httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) @@ -252,10 +252,10 @@ class MetricsResource(Resource): isLeaf = True - def __init__(self, registry=REGISTRY): + def __init__(self, registry: CollectorRegistry = REGISTRY): self.registry = registry - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) response = generate_latest(self.registry) request.setHeader(b"Content-Length", str(len(response))) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 2ab599a334..53c508af91 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -15,19 +15,37 @@ import logging import threading from functools import wraps -from typing import TYPE_CHECKING, Dict, Optional, Set, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Set, + Type, + TypeVar, + Union, + cast, +) +from prometheus_client import Metric from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + PreserveLoggingContext, +) from synapse.logging.opentracing import ( SynapseTags, noop_context_manager, start_active_span, ) -from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -116,7 +134,7 @@ class _Collector: before they are returned. """ - def collect(self): + def collect(self) -> Iterable[Metric]: global _background_processes_active_since_last_scrape # We swap out the _background_processes set with an empty one so that @@ -144,12 +162,12 @@ REGISTRY.register(_Collector()) class _BackgroundProcess: - def __init__(self, desc, ctx): + def __init__(self, desc: str, ctx: LoggingContext): self.desc = desc self._context = ctx - self._reported_stats = None + self._reported_stats: Optional[ContextResourceUsage] = None - def update_metrics(self): + def update_metrics(self) -> None: """Updates the metrics with values from this process.""" new_stats = self._context.get_resource_usage() if self._reported_stats is None: @@ -169,7 +187,16 @@ class _BackgroundProcess: ) -def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): +R = TypeVar("R") + + +def run_as_background_process( + desc: str, + func: Callable[..., Awaitable[Optional[R]]], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> "defer.Deferred[Optional[R]]": """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar args: positional args for func kwargs: keyword args for func - Returns: Deferred which returns the result of func, but note that it does not - follow the synapse logcontext rules. + Returns: + Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. """ - async def run(): + async def run() -> Optional[R]: with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar else: ctx = noop_context_manager() with ctx: - return await maybe_awaitable(func(*args, **kwargs)) + return await func(*args, **kwargs) except Exception: logger.exception( "Background process '%s' threw an exception", desc, ) + return None finally: _background_process_in_flight_count.labels(desc).dec() @@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar return defer.ensureDeferred(run()) -def wrap_as_background_process(desc): +F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]]) + + +def wrap_as_background_process(desc: str) -> Callable[[F], F]: """Decorator that wraps a function that gets called as a background process. - Equivalent of calling the function with `run_as_background_process` + Equivalent to calling the function with `run_as_background_process` """ - def wrap_as_background_process_inner(func): + def wrap_as_background_process_inner(func: F) -> F: @wraps(func) - def wrap_as_background_process_inner_2(*args, **kwargs): + def wrap_as_background_process_inner_2( + *args: Any, **kwargs: Any + ) -> "defer.Deferred[Optional[R]]": return run_as_background_process(desc, func, *args, **kwargs) - return wrap_as_background_process_inner_2 + return cast(F, wrap_as_background_process_inner_2) return wrap_as_background_process_inner @@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__init__("%s-%s" % (name, instance_id)) self._proc = _BackgroundProcess(name, self) - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """Log context has started running (again).""" super().start(rusage) @@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext): with _bg_metrics_lock: _background_processes_active_since_last_scrape.add(self._proc) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Log context has finished.""" super().__exit__(type, value, traceback) diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 29ab6c0229..98ed9c0829 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -16,14 +16,16 @@ import ctypes import logging import os import re -from typing import Optional +from typing import Iterable, Optional + +from prometheus_client import Metric from synapse.metrics import REGISTRY, GaugeMetricFamily logger = logging.getLogger(__name__) -def _setup_jemalloc_stats(): +def _setup_jemalloc_stats() -> None: """Checks to see if jemalloc is loaded, and hooks up a collector to record statistics exposed by jemalloc. """ @@ -135,7 +137,7 @@ def _setup_jemalloc_stats(): class JemallocCollector: """Metrics for internal jemalloc stats.""" - def collect(self): + def collect(self) -> Iterable[Metric]: _jemalloc_refresh_stats() g = GaugeMetricFamily( @@ -185,7 +187,7 @@ def _setup_jemalloc_stats(): logger.debug("Added jemalloc stats") -def setup_jemalloc_stats(): +def setup_jemalloc_stats() -> None: """Try to setup jemalloc stats, if jemalloc is loaded.""" try: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d4cab69ebf..0693d39006 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -188,7 +188,7 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. -_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]] +_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] R = TypeVar("R") @@ -235,7 +235,7 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any): + def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -247,7 +247,7 @@ class LoggingTransaction: self.after_callbacks.append((callback, args, kwargs)) def call_on_exception( - self, callback: Callable[..., None], *args: Any, **kwargs: Any + self, callback: Callable[..., object], *args: Any, **kwargs: Any ): # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 6a7e534576..67ee4c693b 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -159,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]): self[key] = value return value - def _prune_cache(self) -> None: + async def _prune_cache(self) -> None: if not self._expiry_ms: # zero expiry time means don't expire. This should never get called # since we have this check in start too. diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index ad775dfc7d..98ee49af6e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -56,8 +56,15 @@ block_db_sched_duration = Counter( "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] ) + +# This is dynamically created in InFlightGauge.__init__. +class _InFlightMetric(Protocol): + real_time_max: float + real_time_sum: float + + # Tracks the number of blocks currently active -in_flight = InFlightGauge( +in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge( "synapse_util_metrics_block_in_flight", "", labels=["block_name"], @@ -65,12 +72,6 @@ in_flight = InFlightGauge( ) -# This is dynamically created in InFlightGauge.__init__. -class _InFlightMetric(Protocol): - real_time_max: float - real_time_sum: float - - T = TypeVar("T", bound=Callable[..., Any]) -- cgit 1.5.1 From 4bd54b263ef7e2ac29acdc85e0c6392684c44281 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 18 Nov 2021 08:43:09 -0500 Subject: Do not allow MSC3440 threads to fork threads (#11161) Adds validation to the Client-Server API to ensure that the potential thread head does not relate to another event already. This results in not allowing a thread to "fork" into other threads. If the target event is unknown for some reason (maybe it isn't visible to your homeserver), but is the target of other events it is assumed that the thread can be created from it. Otherwise, it is rejected as an unknown event. --- changelog.d/11161.feature | 1 + synapse/handlers/message.py | 54 ++++++++++++++++++++--- synapse/storage/databases/main/relations.py | 67 ++++++++++++++++++++++++++++- tests/rest/client/test_relations.py | 62 ++++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 8 deletions(-) create mode 100644 changelog.d/11161.feature (limited to 'synapse/handlers') diff --git a/changelog.d/11161.feature b/changelog.d/11161.feature new file mode 100644 index 0000000000..76b0d28084 --- /dev/null +++ b/changelog.d/11161.feature @@ -0,0 +1 @@ +Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d4c2a6ab7a..22dd4cf5fd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1001,13 +1001,52 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) + await self._validate_event_relation(event) + logger.debug("Created event %s", event.event_id) + + return event, context + + async def _validate_event_relation(self, event: EventBase) -> None: + """ + Ensure the relation data on a new event is not bogus. + + Args: + event: The event being created. + + Raises: + SynapseError if the event is invalid. + """ + + relation = event.content.get("m.relates_to") + if not relation: + return + + relation_type = relation.get("rel_type") + if not relation_type: + return + + # Ensure the parent is real. + relates_to = relation.get("event_id") + if not relates_to: + return + + parent_event = await self.store.get_event(relates_to, allow_none=True) + if parent_event: + # And in the same room. + if parent_event.room_id != event.room_id: + raise SynapseError(400, "Relations must be in the same room") + + else: + # There must be some reason that the client knows the event exists, + # see if there are existing relations. If so, assume everything is fine. + if not await self.store.event_is_target_of_relation(relates_to): + # Otherwise, the client can't know about the parent event! + raise SynapseError(400, "Can't send relation to unknown event") # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an # event multiple times). - relation = event.content.get("m.relates_to", {}) - if relation.get("rel_type") == RelationTypes.ANNOTATION: - relates_to = relation["event_id"] + if relation_type == RelationTypes.ANNOTATION: aggregation_key = relation["key"] already_exists = await self.store.has_user_annotated_event( @@ -1016,9 +1055,12 @@ class EventCreationHandler: if already_exists: raise SynapseError(400, "Can't send same reaction twice") - logger.debug("Created event %s", event.event_id) - - return event, context + # Don't attempt to start a thread if the parent event is a relation. + elif relation_type == RelationTypes.THREAD: + if await self.store.event_includes_relation(relates_to): + raise SynapseError( + 400, "Cannot start threads from an event with a relation" + ) @measure_func("handle_new_client_event") async def handle_new_client_event( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 907af10995..0a43acda07 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def event_includes_relation(self, event_id: str) -> bool: + """Check if the given event relates to another event. + + An event has a relation if it has a valid m.relates_to with a rel_type + and event_id in the content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$other_event_id" + } + } + } + + Args: + event_id: The event to check. + + Returns: + True if the event includes a valid relation. + """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"event_id": event_id}, + retcol="event_id", + allow_none=True, + desc="event_includes_relation", + ) + return result is not None + + async def event_is_target_of_relation(self, parent_id: str) -> bool: + """Check if the given event is the target of another event's relation. + + An event is the target of an event relation if it has a valid + m.relates_to with a rel_type and event_id pointing to parent_id in the + content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$parent_id" + } + } + } + + Args: + parent_id: The event to check. + + Returns: + True if the event is the target of another event's relation. + """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"relates_to_id": parent_id}, + retcol="event_id", + allow_none=True, + desc="event_is_target_of_relation", + ) + return result is not None + @cached(tree=True) async def get_aggregation_groups_for_event( self, @@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore): %s; """ - def _get_if_event_has_relations(txn) -> List[str]: + def _get_if_events_have_relations(txn) -> List[str]: clauses: List[str] = [] clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", parent_ids @@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore): return [row[0] for row in txn] return await self.db_pool.runInteraction( - "get_if_event_has_relations", _get_if_event_has_relations + "get_if_events_have_relations", _get_if_events_have_relations ) async def has_user_annotated_event( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 78c2fb86b9..b8a1b92a89 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -91,6 +91,49 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) + def test_deny_invalid_event(self): + """Test that we deny relations on non-existant events""" + channel = self._send_relation( + RelationTypes.ANNOTATION, + EventTypes.Message, + parent_id="foo", + content={"body": "foo", "msgtype": "m.text"}, + ) + self.assertEquals(400, channel.code, channel.json_body) + + # Unless that event is referenced from another event! + self.get_success( + self.hs.get_datastore().db_pool.simple_insert( + table="event_relations", + values={ + "event_id": "bar", + "relates_to_id": "foo", + "relation_type": RelationTypes.THREAD, + }, + desc="test_deny_invalid_event", + ) + ) + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + parent_id="foo", + content={"body": "foo", "msgtype": "m.text"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + def test_deny_invalid_room(self): + """Test that we deny relations on non-existant events""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Attempt to send an annotation to that event. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + ) + self.assertEquals(400, channel.code, channel.json_body) + def test_deny_double_react(self): """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") @@ -99,6 +142,25 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(400, channel.code, channel.json_body) + def test_deny_forked_thread(self): + """It is invalid to start a thread off a thread.""" + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=self.parent_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + parent_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=parent_id, + ) + self.assertEquals(400, channel.code, channel.json_body) + def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") -- cgit 1.5.1 From 433ee159cbc89f5f00a6c5ef7849007a81d58baf Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 18 Nov 2021 14:45:38 +0000 Subject: Rename `get_refresh_token_for_user_id` to `create_refresh_token_for_user_id` (#11370) --- changelog.d/11370.misc | 1 + synapse/handlers/auth.py | 4 ++-- synapse/handlers/register.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11370.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11370.misc b/changelog.d/11370.misc new file mode 100644 index 0000000000..13d9f36bf3 --- /dev/null +++ b/changelog.d/11370.misc @@ -0,0 +1 @@ +Rename `get_refresh_token_for_user_id` to `create_refresh_token_for_user_id` to better describe what it does. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b2c84a0fce..4b66a9862f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -790,7 +790,7 @@ class AuthHandler: ( new_refresh_token, new_refresh_token_id, - ) = await self.get_refresh_token_for_user_id( + ) = await self.create_refresh_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id ) access_token = await self.create_access_token_for_user_id( @@ -832,7 +832,7 @@ class AuthHandler: return True - async def get_refresh_token_for_user_id( + async def create_refresh_token_for_user_id( self, user_id: str, device_id: str, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6b5c3d1974..b5664e9fde 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -813,7 +813,7 @@ class RegistrationHandler: ( refresh_token, refresh_token_id, - ) = await self._auth_handler.get_refresh_token_for_user_id( + ) = await self._auth_handler.create_refresh_token_for_user_id( user_id, device_id=registered_device_id, ) -- cgit 1.5.1 From 92b75388f520cfec412bf2ff57bcdbfa22d8c01d Mon Sep 17 00:00:00 2001 From: Shay Date: Thu, 18 Nov 2021 10:56:32 -0800 Subject: Remove legacy code related to deprecated `trust_identity_server_for_password_resets` config flag (#11333) * remove code legacy code related to deprecated config flag "trust_identity_server_for_password_resets" from synapse/config/emailconfig.py * remove legacy code supporting depreciated config flag "trust_identity_server_for_password_resets" from synapse/config/registration.py * remove legacy code supporting depreciated config flag "trust_identity_server_for_password_resets" from synapse/handlers/identity.py * add tests to ensure config error is thrown and synapse refuses to start when depreciated config flag is found * add changelog * slightly change behavior to only check for deprecated flag if set to 'true' * Update changelog.d/11333.misc Co-authored-by: reivilibre Co-authored-by: reivilibre --- changelog.d/11333.misc | 1 + synapse/config/emailconfig.py | 33 +++++++-------------------------- synapse/config/registration.py | 4 +--- synapse/handlers/identity.py | 18 ------------------ tests/config/test_load.py | 9 +++++++++ 5 files changed, 18 insertions(+), 47 deletions(-) create mode 100644 changelog.d/11333.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11333.misc b/changelog.d/11333.misc new file mode 100644 index 0000000000..6c1fd560ad --- /dev/null +++ b/changelog.d/11333.misc @@ -0,0 +1 @@ +Remove deprecated `trust_identity_server_for_password_resets` configuration flag. \ No newline at end of file diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index afd65fecd3..510b647c63 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -137,33 +137,14 @@ class EmailConfig(Config): if self.root.registration.account_threepid_delegate_email else ThreepidBehaviour.LOCAL ) - # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would - # use an identity server to password reset tokens on its behalf. We now warn the user - # if they have this set and tell them to use the updated option, while using a default - # identity server in the process. - self.using_identity_server_from_trusted_list = False - if ( - not self.root.registration.account_threepid_delegate_email - and config.get("trust_identity_server_for_password_resets", False) is True - ): - # Use the first entry in self.trusted_third_party_id_servers instead - if self.trusted_third_party_id_servers: - # XXX: It's a little confusing that account_threepid_delegate_email is modified - # both in RegistrationConfig and here. We should factor this bit out - first_trusted_identity_server = self.trusted_third_party_id_servers[0] - - # trusted_third_party_id_servers does not contain a scheme whereas - # account_threepid_delegate_email is expected to. Presume https - self.root.registration.account_threepid_delegate_email = ( - "https://" + first_trusted_identity_server - ) - self.using_identity_server_from_trusted_list = True - else: - raise ConfigError( - "Attempted to use an identity server from" - '"trusted_third_party_id_servers" but it is empty.' - ) + if config.get("trust_identity_server_for_password_resets"): + raise ConfigError( + 'The config option "trust_identity_server_for_password_resets" ' + 'has been replaced by "account_threepid_delegate". ' + "Please consult the sample config at docs/sample_config.yaml for " + "details and update your config file." + ) self.local_threepid_handling_disabled_due_to_email_config = False if ( diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 5379e80715..66382a479e 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -39,9 +39,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) - self.trusted_third_party_id_servers = config.get( - "trusted_third_party_id_servers", ["matrix.org", "vector.im"] - ) + account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 3dbe611f95..c83eaea359 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -464,15 +464,6 @@ class IdentityHandler: if next_link: params["next_link"] = next_link - if self.hs.config.email.using_identity_server_from_trusted_list: - # Warn that a deprecated config option is in use - logger.warning( - 'The config option "trust_identity_server_for_password_resets" ' - 'has been replaced by "account_threepid_delegate". ' - "Please consult the sample config at docs/sample_config.yaml for " - "details and update your config file." - ) - try: data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/email/requestToken", @@ -517,15 +508,6 @@ class IdentityHandler: if next_link: params["next_link"] = next_link - if self.hs.config.email.using_identity_server_from_trusted_list: - # Warn that a deprecated config option is in use - logger.warning( - 'The config option "trust_identity_server_for_password_resets" ' - 'has been replaced by "account_threepid_delegate". ' - "Please consult the sample config at docs/sample_config.yaml for " - "details and update your config file." - ) - try: data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 765258c47a..d8668d56b2 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -94,3 +94,12 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): # The default Metrics Flags are off by default. config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.metrics.metrics_flags.known_servers) + + def test_depreciated_identity_server_flag_throws_error(self): + self.generate_config() + # Needed to ensure that actual key/value pair added below don't end up on a line with a comment + self.add_lines_to_config([" "]) + # Check that presence of "trust_identity_server_for_password" throws config error + self.add_lines_to_config(["trust_identity_server_for_password_resets: true"]) + with self.assertRaises(ConfigError): + HomeServerConfig.load_config("", ["-c", self.config_file]) -- cgit 1.5.1 From 7ffddd819c2e0e7edff141e987372205894084b6 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 18 Nov 2021 14:16:08 -0600 Subject: Prevent historical state from being pushed to an application service via `/transactions` (MSC2716) (#11265) Mark historical state from the MSC2716 `/batch_send` endpoint as `historical` which makes it `backfilled` and have a negative `stream_ordering` so it doesn't get queried by `/transactions`. Fix https://github.com/matrix-org/synapse/issues/11241 Complement tests: https://github.com/matrix-org/complement/pull/221 --- changelog.d/11265.bugfix | 1 + synapse/appservice/api.py | 23 +++++++++++++++++++++-- synapse/handlers/room_batch.py | 2 ++ synapse/handlers/room_member.py | 15 +++++++++++++++ 4 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11265.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11265.bugfix b/changelog.d/11265.bugfix new file mode 100644 index 0000000000..b0e9dfac53 --- /dev/null +++ b/changelog.d/11265.bugfix @@ -0,0 +1 @@ +Prevent [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical state events from being pushed to an application service via `/transactions`. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d08f6bbd7f..f51b636417 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -231,13 +231,32 @@ class ApplicationServiceApi(SimpleHttpClient): json_body=body, args={"access_token": service.hs_token}, ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "push_bulk to %s succeeded! events=%s", + uri, + [event.get("event_id") for event in events], + ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) return True except CodeMessageException as e: - logger.warning("push_bulk to %s received %s", uri, e.code) + logger.warning( + "push_bulk to %s received code=%s msg=%s", + uri, + e.code, + e.msg, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) except Exception as ex: - logger.warning("push_bulk to %s threw exception %s", uri, ex) + logger.warning( + "push_bulk to %s threw exception(%s) %s args=%s", + uri, + type(ex).__name__, + ex, + ex.args, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) failed_transactions_counter.labels(service.id).inc() return False diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 0723286383..f880aa93d2 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -221,6 +221,7 @@ class RoomBatchHandler: action=membership, content=event_dict["content"], outlier=True, + historical=True, prev_event_ids=[prev_event_id_for_state_chain], # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same @@ -240,6 +241,7 @@ class RoomBatchHandler: ), event_dict, outlier=True, + historical=True, prev_event_ids=[prev_event_id_for_state_chain], # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 08244b690d..a6dbff637f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -268,6 +268,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: Optional[dict] = None, require_consent: bool = True, outlier: bool = False, + historical: bool = False, ) -> Tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -293,6 +294,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. Returns: Tuple of event ID and stream ordering position @@ -337,6 +341,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): auth_event_ids=auth_event_ids, require_consent=require_consent, outlier=outlier, + historical=historical, ) prev_state_ids = await context.get_prev_state_ids() @@ -433,6 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, + historical: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -454,6 +460,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -487,6 +496,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room=new_room, require_consent=require_consent, outlier=outlier, + historical=historical, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, ) @@ -507,6 +517,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, + historical: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -530,6 +541,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -657,6 +671,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content=content, require_consent=require_consent, outlier=outlier, + historical=historical, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) -- cgit 1.5.1 From 7ae559944af1b27e36a6f4423177f51d7b4b3826 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 19 Nov 2021 10:19:32 -0500 Subject: Fix checking whether a room can be published on creation. (#11392) If `room_list_publication_rules` was configured with a rule with a non-wildcard alias and a room was created with an alias then an internal server error would have been thrown. This fixes the error and properly applies the publication rules during room creation. --- changelog.d/11392.bugfix | 1 + synapse/config/room_directory.py | 50 +++++++++++---------- synapse/handlers/room.py | 5 ++- tests/handlers/test_directory.py | 95 ++++++++++++++++++++++++++-------------- 4 files changed, 95 insertions(+), 56 deletions(-) create mode 100644 changelog.d/11392.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/11392.bugfix b/changelog.d/11392.bugfix new file mode 100644 index 0000000000..fb15800327 --- /dev/null +++ b/changelog.d/11392.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.13.0 where creating and publishing a room could cause errors if `room_list_publication_rules` is configured. diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 56981cac79..57316c59b6 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -1,4 +1,5 @@ # Copyright 2018 New Vector Ltd +# Copyright 2021 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +from synapse.types import JsonDict from synapse.util import glob_to_regex from ._base import Config, ConfigError @@ -20,7 +24,7 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): section = "roomdirectory" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") @@ -47,7 +51,7 @@ class RoomDirectoryConfig(Config): _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) ] - def generate_config_section(self, config_dir_path, server_name, **kwargs): + def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: return """ # Uncomment to disable searching the public room list. When disabled # blocks searching local and remote room lists for local and remote @@ -113,16 +117,16 @@ class RoomDirectoryConfig(Config): # action: allow """ - def is_alias_creation_allowed(self, user_id, room_id, alias): + def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> bool: """Checks if the given user is allowed to create the given alias Args: - user_id (str) - room_id (str) - alias (str) + user_id: The user to check. + room_id: The room ID for the alias. + alias: The alias being created. Returns: - boolean: True if user is allowed to create the alias + True if user is allowed to create the alias """ for rule in self._alias_creation_rules: if rule.matches(user_id, room_id, [alias]): @@ -130,16 +134,18 @@ class RoomDirectoryConfig(Config): return False - def is_publishing_room_allowed(self, user_id, room_id, aliases): + def is_publishing_room_allowed( + self, user_id: str, room_id: str, aliases: List[str] + ) -> bool: """Checks if the given user is allowed to publish the room Args: - user_id (str) - room_id (str) - aliases (list[str]): any local aliases associated with the room + user_id: The user ID publishing the room. + room_id: The room being published. + aliases: any local aliases associated with the room Returns: - boolean: True if user can publish room + True if user can publish room """ for rule in self._room_list_publication_rules: if rule.matches(user_id, room_id, aliases): @@ -153,11 +159,11 @@ class _RoomDirectoryRule: creating an alias or publishing a room. """ - def __init__(self, option_name, rule): + def __init__(self, option_name: str, rule: JsonDict): """ Args: - option_name (str): Name of the config option this rule belongs to - rule (dict): The rule as specified in the config + option_name: Name of the config option this rule belongs to + rule: The rule as specified in the config """ action = rule["action"] @@ -181,18 +187,18 @@ class _RoomDirectoryRule: except Exception as e: raise ConfigError("Failed to parse glob into regex") from e - def matches(self, user_id, room_id, aliases): + def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: """Tests if this rule matches the given user_id, room_id and aliases. Args: - user_id (str) - room_id (str) - aliases (list[str]): The associated aliases to the room. Will be a - single element for testing alias creation, and can be empty for - testing room publishing. + user_id: The user ID to check. + room_id: The room ID to check. + aliases: The associated aliases to the room. Will be a single element + for testing alias creation, and can be empty for testing room + publishing. Returns: - boolean + True if the rule matches. """ # Note: The regexes are anchored at both ends diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f9a099c4f3..88053f9869 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -775,8 +775,11 @@ class RoomCreationHandler: raise SynapseError(403, "Room visibility value not allowed.") if is_public: + room_aliases = [] + if room_alias: + room_aliases.append(room_alias.to_string()) if not self.config.roomdirectory.is_publishing_room_allowed( - user_id, room_id, room_alias + user_id, room_id, room_aliases ): # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index be008227df..0ea4e753e2 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - from unittest.mock import Mock import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.config.room_directory import RoomDirectoryConfig from synapse.rest.client import directory, login, room from synapse.types import RoomAlias, create_requester @@ -394,22 +393,15 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare(self, reactor, clock, hs): - # We cheekily override the config to add custom alias creation rules - config = {} + def default_config(self): + config = super().default_config() + + # Add custom alias creation rules to the config. config["alias_creation_rules"] = [ {"user_id": "*", "alias": "#unofficial_*", "action": "allow"} ] - config["room_list_publication_rules"] = [] - rd_config = RoomDirectoryConfig() - rd_config.read_config(config) - - self.hs.config.roomdirectory.is_alias_creation_allowed = ( - rd_config.is_alias_creation_allowed - ) - - return hs + return config def test_denied(self): room_id = self.helper.create_room_as(self.user_id) @@ -417,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), + {"room_id": room_id}, ) self.assertEquals(403, channel.code, channel.result) @@ -427,14 +419,35 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), + {"room_id": room_id}, ) self.assertEquals(200, channel.code, channel.result) + def test_denied_during_creation(self): + """A room alias that is not allowed should be rejected during creation.""" + # Invalid room alias. + self.helper.create_room_as( + self.user_id, + expect_code=403, + extra_content={"room_alias_name": "foo"}, + ) -class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): - data = {"room_alias_name": "unofficial_test"} + def test_allowed_during_creation(self): + """A valid room alias should be allowed during creation.""" + room_id = self.helper.create_room_as( + self.user_id, + extra_content={"room_alias_name": "unofficial_test"}, + ) + channel = self.make_request( + "GET", + b"directory/room/%23unofficial_test%3Atest", + ) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(channel.json_body["room_id"], room_id) + + +class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -443,27 +456,30 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): ] hijack_auth = False - def prepare(self, reactor, clock, hs): - self.allowed_user_id = self.register_user("allowed", "pass") - self.allowed_access_token = self.login("allowed", "pass") + data = {"room_alias_name": "unofficial_test"} + allowed_localpart = "allowed" - self.denied_user_id = self.register_user("denied", "pass") - self.denied_access_token = self.login("denied", "pass") + def default_config(self): + config = super().default_config() - # This time we add custom room list publication rules - config = {} - config["alias_creation_rules"] = [] + # Add custom room list publication rules to the config. config["room_list_publication_rules"] = [ + { + "user_id": "@" + self.allowed_localpart + "*", + "alias": "#unofficial_*", + "action": "allow", + }, {"user_id": "*", "alias": "*", "action": "deny"}, - {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"}, ] - rd_config = RoomDirectoryConfig() - rd_config.read_config(config) + return config - self.hs.config.roomdirectory.is_publishing_room_allowed = ( - rd_config.is_publishing_room_allowed - ) + def prepare(self, reactor, clock, hs): + self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") + self.allowed_access_token = self.login(self.allowed_localpart, "pass") + + self.denied_user_id = self.register_user("denied", "pass") + self.denied_access_token = self.login("denied", "pass") return hs @@ -505,10 +521,23 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): self.allowed_user_id, tok=self.allowed_access_token, extra_content=self.data, - is_public=False, + is_public=True, expect_code=200, ) + def test_denied_publication_with_invalid_alias(self): + """ + Try to create a room, register an alias for it, and publish it, + as a user WITH permission to publish rooms. + """ + self.helper.create_room_as( + self.allowed_user_id, + tok=self.allowed_access_token, + extra_content={"room_alias_name": "foo"}, + is_public=True, + expect_code=403, + ) + def test_can_create_as_private_room_after_rejection(self): """ After failing to publish a room with an alias as a user without publish permission, -- cgit 1.5.1 From 6a5dd485bd82b269e7e169c0385290d081eae801 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 Nov 2021 06:43:56 -0500 Subject: Refactor the code to inject bundled relations during serialization. (#11408) --- changelog.d/11408.misc | 1 + synapse/events/utils.py | 146 ++++++++++++++++++++++----------------- synapse/handlers/events.py | 2 +- synapse/handlers/message.py | 2 +- synapse/rest/admin/rooms.py | 4 +- synapse/rest/client/relations.py | 6 +- synapse/rest/client/room.py | 2 +- synapse/rest/client/sync.py | 2 +- 8 files changed, 92 insertions(+), 73 deletions(-) create mode 100644 changelog.d/11408.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11408.misc b/changelog.d/11408.misc new file mode 100644 index 0000000000..55ed064672 --- /dev/null +++ b/changelog.d/11408.misc @@ -0,0 +1 @@ +Refactor including the bundled relations when serializing an event. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 6fa631aa1d..e5967c995e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -392,15 +393,16 @@ class EventClientSerializer: self, event: Union[JsonDict, EventBase], time_now: int, - bundle_aggregations: bool = True, + bundle_relations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. Args: - event + event: The event being serialized. time_now: The current time in milliseconds - bundle_aggregations: Whether to bundle in related events + bundle_relations: Whether to include the bundled relations for this + event. **kwargs: Arguments to pass to `serialize_event` Returns: @@ -410,77 +412,93 @@ class EventClientSerializer: if not isinstance(event, EventBase): return event - event_id = event.event_id serialized_event = serialize_event(event, time_now, **kwargs) # If MSC1849 is enabled then we need to look if there are any relations # we need to bundle in with the event. # Do not bundle relations if the event has been redacted if not event.internal_metadata.is_redacted() and ( - self._msc1849_enabled and bundle_aggregations + self._msc1849_enabled and bundle_relations ): - annotations = await self.store.get_aggregation_groups_for_event(event_id) - references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" - ) - - if annotations.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.ANNOTATION] = annotations.to_dict() - - if references.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REFERENCE] = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) - - if edit: - # If there is an edit replace the content, preserving existing - # relations. - - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relations = event.content.get("m.relates_to") - if relations: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relations.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACE] = { - "event_id": edit.event_id, - "origin_server_ts": edit.origin_server_ts, - "sender": edit.sender, - } - - # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.store.get_thread_summary(event_id) - if latest_thread_event: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False - ), - "count": thread_count, - } + await self._injected_bundled_relations(event, time_now, serialized_event) return serialized_event + async def _injected_bundled_relations( + self, event: EventBase, time_now: int, serialized_event: JsonDict + ) -> None: + """Potentially injects bundled relations into the unsigned portion of the serialized event. + + Args: + event: The event being serialized. + time_now: The current time in milliseconds + serialized_event: The serialized event which may be modified. + + """ + event_id = event.event_id + + # The bundled relations to include. + relations = {} + + annotations = await self.store.get_aggregation_groups_for_event(event_id) + if annotations.chunk: + relations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.store.get_relations_for_event( + event_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + relations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.store.get_applicable_edit(event_id) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + + relations[RelationTypes.REPLACE] = { + "event_id": edit.event_id, + "origin_server_ts": edit.origin_server_ts, + "sender": edit.sender, + } + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.store.get_thread_summary(event_id) + if latest_thread_event: + relations[RelationTypes.THREAD] = { + # Don't bundle relations as this could recurse forever. + "latest_event": await self.serialize_event( + latest_thread_event, time_now, bundle_relations=False + ), + "count": thread_count, + } + + # If any bundled relations were found, include them. + if relations: + serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) + async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any ) -> List[JsonDict]: diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1f64534a8a..b4ff935546 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -124,7 +124,7 @@ class EventStreamHandler: as_client_event=as_client_event, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. - bundle_aggregations=False, + bundle_relations=False, ) chunk = { diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 22dd4cf5fd..95b4fad3c6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -252,7 +252,7 @@ class MessageHandler: now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. - bundle_aggregations=False, + bundle_relations=False, ) return events diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 5b8ec1e5ca..a89dda1ba5 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -448,7 +448,7 @@ class RoomStateRestServlet(RestServlet): now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. - bundle_aggregations=False, + bundle_relations=False, ) ret = {"state": room_state} @@ -778,7 +778,7 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now, # No need to bundle aggregations for state events - bundle_aggregations=False, + bundle_relations=False, ) return 200, results diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 184cfbe196..45e9f1dd90 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,17 +224,17 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_aggregations to False when retrieving the original + # We set bundle_relations to False when retrieving the original # event because we want the content before relations were applied to # it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + event, now, bundle_relations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=False + events, now, bundle_relations=False ) return_value = pagination_chunk.to_dict() diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 03a353d53c..955d4e8641 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -719,7 +719,7 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now, # No need to bundle aggregations for state events - bundle_aggregations=False, + bundle_relations=False, ) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8c0fdb1940..b6a2485732 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -522,7 +522,7 @@ class SyncRestServlet(RestServlet): time_now=time_now, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. - bundle_aggregations=False, + bundle_relations=False, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, -- cgit 1.5.1 From f25c75d376ff8517e3245e06abdacf813606bd1f Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 23 Nov 2021 17:01:34 +0000 Subject: Rename unstable `access_token_lifetime` configuration option to `refreshable_access_token_lifetime` to make it clear it only concerns refreshable access tokens. (#11388) --- changelog.d/11388.misc | 1 + synapse/config/registration.py | 23 +++++++++++++++-------- synapse/handlers/register.py | 8 ++++++-- synapse/rest/client/login.py | 14 ++++++++++---- synapse/rest/client/register.py | 4 +++- tests/rest/client/test_auth.py | 2 +- 6 files changed, 36 insertions(+), 16 deletions(-) create mode 100644 changelog.d/11388.misc (limited to 'synapse/handlers') diff --git a/changelog.d/11388.misc b/changelog.d/11388.misc new file mode 100644 index 0000000000..7ce7ad0498 --- /dev/null +++ b/changelog.d/11388.misc @@ -0,0 +1 @@ +Rename unstable `access_token_lifetime` configuration option to `refreshable_access_token_lifetime` to make it clear it only concerns refreshable access tokens. \ No newline at end of file diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 66382a479e..61e569d412 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -112,25 +112,32 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime - # The `access_token_lifetime` applies for tokens that can be renewed + # The `refreshable_access_token_lifetime` applies for tokens that can be renewed # using a refresh token, as per MSC2918. If it is `None`, the refresh # token mechanism is disabled. # # Since it is incompatible with the `session_lifetime` mechanism, it is set to # `None` by default if a `session_lifetime` is set. - access_token_lifetime = config.get( - "access_token_lifetime", "5m" if session_lifetime is None else None + refreshable_access_token_lifetime = config.get( + "refreshable_access_token_lifetime", + "5m" if session_lifetime is None else None, ) - if access_token_lifetime is not None: - access_token_lifetime = self.parse_duration(access_token_lifetime) - self.access_token_lifetime = access_token_lifetime + if refreshable_access_token_lifetime is not None: + refreshable_access_token_lifetime = self.parse_duration( + refreshable_access_token_lifetime + ) + self.refreshable_access_token_lifetime = refreshable_access_token_lifetime - if session_lifetime is not None and access_token_lifetime is not None: + if ( + session_lifetime is not None + and refreshable_access_token_lifetime is not None + ): raise ConfigError( "The refresh token mechanism is incompatible with the " "`session_lifetime` option. Consider disabling the " "`session_lifetime` option or disabling the refresh token " - "mechanism by removing the `access_token_lifetime` option." + "mechanism by removing the `refreshable_access_token_lifetime` " + "option." ) # The fallback template used for authenticating using a registration token diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b5664e9fde..448a36108e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,7 +116,9 @@ class RegistrationHandler: self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) init_counters_for_auth_provider("") @@ -817,7 +819,9 @@ class RegistrationHandler: user_id, device_id=registered_device_id, ) - valid_until_ms = self.clock.time_msec() + self.access_token_lifetime + valid_until_ms = ( + self.clock.time_msec() + self.refreshable_access_token_lifetime + ) access_token = await self._auth_handler.create_access_token_for_user_id( user_id, diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 00e65c66ac..67e03dca04 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -81,7 +81,9 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._msc2918_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self.auth = hs.get_auth() @@ -453,7 +455,9 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -463,7 +467,9 @@ class RefreshTokenServlet(RestServlet): if not isinstance(token, str): raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + valid_until_ms = ( + self._clock.time_msec() + self.refreshable_access_token_lifetime + ) access_token, refresh_token = await self._auth_handler.refresh_token( token, valid_until_ms ) @@ -562,7 +568,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.registration.access_token_lifetime is not None: + if hs.config.registration.refreshable_access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index bf3cb34146..d2b11e39d9 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -420,7 +420,9 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.registration.enable_registration - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._msc2918_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index e2fcbdc63a..8552671431 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -598,7 +598,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response.json_body["refresh_token"], ) - @override_config({"access_token_lifetime": "1m"}) + @override_config({"refreshable_access_token_lifetime": "1m"}) def test_refresh_token_expiration(self): """ The access token should have some time as specified in the config. -- cgit 1.5.1