diff options
author | Olivier Wilkinson (reivilibre) <oliverw@matrix.org> | 2021-12-20 11:42:10 +0000 |
---|---|---|
committer | Olivier Wilkinson (reivilibre) <oliverw@matrix.org> | 2021-12-20 11:42:10 +0000 |
commit | 8325ddd0bc9e9771582977ff9c9d54210d21c541 (patch) | |
tree | e7f0b6bc6e986c54f80a18f93de5303b5a0231e0 /synapse/rest | |
parent | Merge branch 'develop' into rei/gsfg_1 (diff) | |
parent | Add type hints to `synapse/tests/rest/admin` (#11590) (diff) | |
download | synapse-8325ddd0bc9e9771582977ff9c9d54210d21c541.tar.xz |
Merge branch 'develop' into rei/gsfg_1
Diffstat (limited to 'synapse/rest')
36 files changed, 1787 insertions, 847 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index e04af705eb..cebdeecb81 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from synapse.http.server import HttpServer, JsonResource from synapse.rest import admin @@ -62,6 +62,8 @@ from synapse.rest.client import ( if TYPE_CHECKING: from synapse.server import HomeServer +RegisterServletsFunc = Callable[["HomeServer", HttpServer], None] + class ClientRestResource(JsonResource): """Matrix Client API REST resource. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 70514e814f..701c609c12 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -17,6 +17,7 @@ import logging import platform +from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple import synapse @@ -25,6 +26,11 @@ from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.rest.admin.background_updates import ( + BackgroundUpdateEnabledRestServlet, + BackgroundUpdateRestServlet, + BackgroundUpdateStartJobRestServlet, +) from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, DeviceRestServlet, @@ -34,6 +40,10 @@ from synapse.rest.admin.event_reports import ( EventReportDetailRestServlet, EventReportsRestServlet, ) +from synapse.rest.admin.federation import ( + DestinationsRestServlet, + ListDestinationsRestServlet, +) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.registration_tokens import ( @@ -42,6 +52,9 @@ from synapse.rest.admin.registration_tokens import ( RegistrationTokenRestServlet, ) from synapse.rest.admin.rooms import ( + BlockRoomRestServlet, + DeleteRoomStatusByDeleteIdRestServlet, + DeleteRoomStatusByRoomIdRestServlet, ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, @@ -49,6 +62,7 @@ from synapse.rest.admin.rooms import ( RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, + RoomRestV2Servlet, RoomStateRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -89,12 +103,12 @@ class VersionServlet(RestServlet): } def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - return 200, self.res + return HTTPStatus.OK, self.res class PurgeHistoryRestServlet(RestServlet): PATTERNS = admin_patterns( - "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" + "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$" ) def __init__(self, hs: "HomeServer"): @@ -121,7 +135,7 @@ class PurgeHistoryRestServlet(RestServlet): event = await self.store.get_event(event_id) if event.room_id != room_id: - raise SynapseError(400, "Event is for wrong room.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.") # RoomStreamToken expects [int] not Optional[int] assert event.internal_metadata.stream_ordering is not None @@ -135,7 +149,9 @@ class PurgeHistoryRestServlet(RestServlet): ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "purge_up_to_ts must be an int", + errcode=Codes.BAD_JSON, ) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) @@ -151,7 +167,9 @@ class PurgeHistoryRestServlet(RestServlet): stream_ordering, ) raise SynapseError( - 404, "there is no event to be purged", errcode=Codes.NOT_FOUND + HTTPStatus.NOT_FOUND, + "there is no event to be purged", + errcode=Codes.NOT_FOUND, ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) @@ -164,7 +182,7 @@ class PurgeHistoryRestServlet(RestServlet): ) else: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "must specify purge_up_to_event_id or purge_up_to_ts", errcode=Codes.BAD_JSON, ) @@ -173,11 +191,11 @@ class PurgeHistoryRestServlet(RestServlet): room_id, token, delete_local_events=delete_local_events ) - return 200, {"purge_id": purge_id} + return HTTPStatus.OK, {"purge_id": purge_id} class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)") + PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() @@ -192,7 +210,7 @@ class PurgeHistoryStatusRestServlet(RestServlet): if purge_status is None: raise NotFoundError("purge id '%s' not found" % purge_id) - return 200, purge_status.asdict() + return HTTPStatus.OK, purge_status.asdict() ######################################################################################## @@ -216,10 +234,14 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: Register all the admin servlets. """ register_servlets_for_client_rest_resource(hs, http_server) + BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomRestV2Servlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) + DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server) + DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) @@ -243,10 +265,15 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + DestinationsRestServlet(hs).register(http_server) + ListDestinationsRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index d9a2f6ca15..399b205aaf 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from http import HTTPStatus from typing import Iterable, Pattern from synapse.api.auth import Auth @@ -62,4 +63,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: """ is_admin = await auth.is_server_admin(user_id) if not is_admin: - raise AuthError(403, "You are not a server admin") + raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py new file mode 100644 index 0000000000..6ec00ce0b9 --- /dev/null +++ b/synapse/rest/admin/background_updates.py @@ -0,0 +1,172 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class BackgroundUpdateEnabledRestServlet(RestServlet): + """Allows temporarily disabling background updates""" + + PATTERNS = admin_patterns("/background_updates/enabled$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + # We need to check that all configured databases have updates enabled. + # (They *should* all be in sync.) + enabled = all(db.updates.enabled for db in self._data_stores.databases) + + return HTTPStatus.OK, {"enabled": enabled} + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + body = parse_json_object_from_request(request) + + enabled = body.get("enabled", True) + + if not isinstance(enabled, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean" + ) + + for db in self._data_stores.databases: + db.updates.enabled = enabled + + # If we're re-enabling them ensure that we start the background + # process again. + if enabled: + db.updates.start_doing_background_updates() + + return HTTPStatus.OK, {"enabled": enabled} + + +class BackgroundUpdateRestServlet(RestServlet): + """Fetch information about background updates""" + + PATTERNS = admin_patterns("/background_updates/status$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + # We need to check that all configured databases have updates enabled. + # (They *should* all be in sync.) + enabled = all(db.updates.enabled for db in self._data_stores.databases) + + current_updates = {} + + for db in self._data_stores.databases: + update = db.updates.get_current_update() + if not update: + continue + + current_updates[db.name()] = { + "name": update.name, + "total_item_count": update.total_item_count, + "total_duration_ms": update.total_duration_ms, + "average_items_per_ms": update.average_items_per_ms(), + } + + return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates} + + +class BackgroundUpdateStartJobRestServlet(RestServlet): + """Allows to start specific background updates""" + + PATTERNS = admin_patterns("/background_updates/start_job$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["job_name"]) + + job_name = body["job_name"] + + if job_name == "populate_stats_process_rooms": + jobs = [ + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + }, + ] + elif job_name == "regenerate_directory": + jobs = [ + { + "update_name": "populate_user_directory_createtables", + "progress_json": "{}", + "depends_on": "", + }, + { + "update_name": "populate_user_directory_process_rooms", + "progress_json": "{}", + "depends_on": "populate_user_directory_createtables", + }, + { + "update_name": "populate_user_directory_process_users", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_rooms", + }, + { + "update_name": "populate_user_directory_cleanup", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_users", + }, + ] + else: + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name") + + try: + await self._store.db_pool.simple_insert_many( + table="background_updates", + values=jobs, + desc=f"admin_api_run_{job_name}", + ) + except self._store.db_pool.engine.module.IntegrityError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Job %s is already in queue of background updates." % (job_name,), + ) + + self._store.db_pool.updates.start_doing_background_updates() + + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 80fbf32f17..d9905ff560 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError @@ -41,10 +42,10 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str, device_id: str @@ -52,8 +53,8 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -62,7 +63,9 @@ class DeviceRestServlet(RestServlet): device = await self.device_handler.get_device( target_user.to_string(), device_id ) - return 200, device + if device is None: + raise NotFoundError("No device found") + return HTTPStatus.OK, device async def on_DELETE( self, request: SynapseRequest, user_id: str, device_id: str @@ -70,15 +73,15 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") await self.device_handler.delete_device(target_user.to_string(), device_id) - return 200, {} + return HTTPStatus.OK, {} async def on_PUT( self, request: SynapseRequest, user_id: str, device_id: str @@ -86,8 +89,8 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -97,7 +100,7 @@ class DeviceRestServlet(RestServlet): await self.device_handler.update_device( target_user.to_string(), device_id, body ) - return 200, {} + return HTTPStatus.OK, {} class DevicesRestServlet(RestServlet): @@ -108,14 +111,10 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"): - """ - Args: - hs: server - """ - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -123,15 +122,15 @@ class DevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") devices = await self.device_handler.get_devices_by_user(target_user.to_string()) - return 200, {"devices": devices, "total": len(devices)} + return HTTPStatus.OK, {"devices": devices, "total": len(devices)} class DeleteDevicesRestServlet(RestServlet): @@ -143,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, user_id: str @@ -154,8 +153,8 @@ class DeleteDevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -167,4 +166,4 @@ class DeleteDevicesRestServlet(RestServlet): await self.device_handler.delete_devices( target_user.to_string(), body["devices"] ) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index bbfcaf723b..38477f8ead 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -51,7 +52,6 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -66,21 +66,23 @@ class EventReportsRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The start parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The limit parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) event_reports, total = await self.store.get_event_reports_paginate( @@ -90,7 +92,7 @@ class EventReportsRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(event_reports) - return 200, ret + return HTTPStatus.OK, ret class EventReportDetailRestServlet(RestServlet): @@ -112,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -127,13 +128,17 @@ class EventReportDetailRestServlet(RestServlet): try: resolved_report_id = int(report_id) except ValueError: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) if resolved_report_id < 0: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) ret = await self.store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") - return 200, ret + return HTTPStatus.OK, ret diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py new file mode 100644 index 0000000000..50d88c9109 --- /dev/null +++ b/synapse/rest/admin/federation.py @@ -0,0 +1,135 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.storage.databases.main.transactions import DestinationSortOrder +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ListDestinationsRestServlet(RestServlet): + """Get request to list all destinations. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v1/federation/destinations?from=0&limit=10 + + returns: + 200 OK with list of destinations if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `destination` can be used to filter by destination. + The parameter `order_by` can be used to order the result. + """ + + PATTERNS = admin_patterns("/federation/destinations$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + destination = parse_string(request, "destination") + + order_by = parse_string( + request, + "order_by", + default=DestinationSortOrder.DESTINATION.value, + allowed_values=[dest.value for dest in DestinationSortOrder], + ) + + direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + + destinations, total = await self._store.get_destinations_paginate( + start, limit, destination, order_by, direction + ) + response = {"destinations": destinations, "total": total} + if (start + limit) < total: + response["next_token"] = str(start + len(destinations)) + + return HTTPStatus.OK, response + + +class DestinationsRestServlet(RestServlet): + """Get details of a destination. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v1/federation/destinations/<destination> + + returns: + 200 OK with details of a destination if success otherwise an error. + """ + + PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, destination: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + destination_retry_timings = await self._store.get_destination_retry_timings( + destination + ) + + if not destination_retry_timings: + raise NotFoundError("Unknown destination") + + last_successful_stream_ordering = ( + await self._store.get_destination_last_successful_stream_ordering( + destination + ) + ) + + response = { + "destination": destination, + "failure_ts": destination_retry_timings.failure_ts, + "retry_last_ts": destination_retry_timings.retry_last_ts, + "retry_interval": destination_retry_timings.retry_interval, + "last_successful_stream_ordering": last_successful_stream_ordering, + } + + return HTTPStatus.OK, response diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index 68a3ba3cb7..cd697e180e 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError @@ -29,7 +30,7 @@ logger = logging.getLogger(__name__) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups""" - PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.group_server = hs.get_groups_server_handler() @@ -43,7 +44,7 @@ class DeleteGroupAdminRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine_id(group_id): - raise SynapseError(400, "Can only delete local groups") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups") await self.group_server.delete_group(group_id, requester.user.to_string()) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 30a687d234..7236e4027f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -14,9 +14,10 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -40,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"), + *admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"), # This path kept around for legacy reasons - *admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"), + *admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"), ] def __init__(self, hs: "HomeServer"): @@ -62,7 +63,7 @@ class QuarantineMediaInRoom(RestServlet): room_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByUser(RestServlet): @@ -70,7 +71,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$") + PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -89,7 +90,7 @@ class QuarantineMediaByUser(RestServlet): user_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByID(RestServlet): @@ -98,7 +99,7 @@ class QuarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" + "/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -118,7 +119,7 @@ class QuarantineMediaByID(RestServlet): server_name, media_id, requester.user.to_string() ) - return 200, {} + return HTTPStatus.OK, {} class UnquarantineMediaByID(RestServlet): @@ -127,7 +128,7 @@ class UnquarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" + "/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -137,8 +138,7 @@ class UnquarantineMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info( "Remove from quarantine local media by ID: %s/%s", server_name, media_id @@ -147,13 +147,13 @@ class UnquarantineMediaByID(RestServlet): # Remove from quarantine this media id await self.store.quarantine_media_by_id(server_name, media_id, None) - return 200, {} + return HTTPStatus.OK, {} class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -162,21 +162,20 @@ class ProtectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Protecting local media by ID: %s", media_id) # Protect this media id await self.store.mark_local_media_as_safe(media_id, safe=True) - return 200, {} + return HTTPStatus.OK, {} class UnprotectMediaByID(RestServlet): """Unprotect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -185,21 +184,20 @@ class UnprotectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Unprotecting local media by ID: %s", media_id) # Unprotect this media id await self.store.mark_local_media_as_safe(media_id, safe=False) - return 200, {} + return HTTPStatus.OK, {} class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$") + PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -208,14 +206,11 @@ class ListMediaInRoom(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - is_admin = await self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + await assert_requester_is_admin(self.auth, request) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) - return 200, {"local": local_mxcs, "remote": remote_mxcs} + return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs} class PurgeMediaCacheRestServlet(RestServlet): @@ -233,13 +228,13 @@ class PurgeMediaCacheRestServlet(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. " + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, @@ -247,13 +242,13 @@ class PurgeMediaCacheRestServlet(RestServlet): ret = await self.media_repository.delete_old_remote_media(before_ts) - return 200, ret + return HTTPStatus.OK, ret class DeleteMediaByID(RestServlet): """Delete local media by a given ID. Removes it from this server.""" - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -267,7 +262,7 @@ class DeleteMediaByID(RestServlet): await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") if await self.store.get_local_media(media_id) is None: raise NotFoundError("Unknown media") @@ -277,7 +272,7 @@ class DeleteMediaByID(RestServlet): deleted_media, total = await self.media_repository.delete_local_media_ids( [media_id] ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class DeleteMediaByDateSize(RestServlet): @@ -285,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$") + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -304,26 +299,26 @@ class DeleteMediaByDateSize(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. " + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, ) if size_gt < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter size_gt must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") logging.info( "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s" @@ -333,7 +328,7 @@ class DeleteMediaByDateSize(RestServlet): deleted_media, total = await self.media_repository.delete_old_local_media( before_ts, size_gt, keep_profiles ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class UserMediaRestServlet(RestServlet): @@ -352,7 +347,7 @@ class UserMediaRestServlet(RestServlet): media that exist given for this user """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -369,7 +364,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -380,14 +375,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -402,16 +397,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -425,7 +411,7 @@ class UserMediaRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(media) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -436,7 +422,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -447,14 +433,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -469,16 +455,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -492,7 +469,7 @@ class UserMediaRestServlet(RestServlet): ([row["media_id"] for row in media]) ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index aba48f6e7b..04948b6408 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -14,6 +14,7 @@ import logging import string +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -69,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -77,7 +77,7 @@ class ListRegistrationTokensRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) valid = parse_boolean(request, "valid") token_list = await self.store.get_registration_tokens(valid) - return 200, {"registration_tokens": token_list} + return HTTPStatus.OK, {"registration_tokens": token_list} class NewRegistrationTokenRestServlet(RestServlet): @@ -108,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/new$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -123,16 +122,20 @@ class NewRegistrationTokenRestServlet(RestServlet): if "token" in body: token = body["token"] if not isinstance(token, str): - raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "token must be a string", + Codes.INVALID_PARAM, + ) if not (0 < len(token) <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must not be empty and must not be longer than 64 characters", Codes.INVALID_PARAM, ) if not set(token).issubset(self.allowed_chars_set): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must consist only of characters matched by the regex [A-Za-z0-9-_]", Codes.INVALID_PARAM, ) @@ -142,11 +145,13 @@ class NewRegistrationTokenRestServlet(RestServlet): length = body.get("length", 16) if not isinstance(length, int): raise SynapseError( - 400, "length must be an integer", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "length must be an integer", + Codes.INVALID_PARAM, ) if not (0 < length <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "length must be greater than zero and not greater than 64", Codes.INVALID_PARAM, ) @@ -162,7 +167,7 @@ class NewRegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -170,11 +175,15 @@ class NewRegistrationTokenRestServlet(RestServlet): expiry_time = body.get("expiry_time", None) if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) created = await self.store.create_registration_token( @@ -182,7 +191,9 @@ class NewRegistrationTokenRestServlet(RestServlet): ) if not created: raise SynapseError( - 400, f"Token already exists: {token}", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + f"Token already exists: {token}", + Codes.INVALID_PARAM, ) resp = { @@ -192,7 +203,7 @@ class NewRegistrationTokenRestServlet(RestServlet): "completed": 0, "expiry_time": expiry_time, } - return 200, resp + return HTTPStatus.OK, resp class RegistrationTokenRestServlet(RestServlet): @@ -247,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.clock = hs.get_clock() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -261,7 +271,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Update a registration token.""" @@ -277,7 +287,7 @@ class RegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -287,11 +297,15 @@ class RegistrationTokenRestServlet(RestServlet): expiry_time = body["expiry_time"] if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) new_attributes["expiry_time"] = expiry_time @@ -307,7 +321,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_DELETE( self, request: SynapseRequest, token: str @@ -316,6 +330,6 @@ class RegistrationTokenRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if await self.store.delete_registration_token(token): - return 200, {} + return HTTPStatus.OK, {} raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 05c5b4bf0c..17c6df1cc8 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from urllib import parse as urlparse from synapse.api.constants import EventTypes, JoinRules, Membership @@ -34,7 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -46,6 +46,139 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class RoomRestV2Servlet(RestServlet): + """Delete a room from server asynchronously with a background task. + + It is a combination and improvement of shutdown and purge room. + + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto- + joined to the new room. + + If 'purge' is true, it will remove all traces of a room from the database. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._pagination_handler = hs.get_pagination_handler() + + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) + + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + delete_id = self._pagination_handler.start_shutdown_and_purge_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, + purge=purge, + force_purge=force_purge, + ) + + return HTTPStatus.OK, {"delete_id": delete_id} + + +class DeleteRoomStatusByRoomIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) + if delete_ids is None: + raise NotFoundError("No delete task for room_id '%s' found" % room_id) + + response = [] + for delete_id in delete_ids: + delete = self._pagination_handler.get_delete_status(delete_id) + if delete: + response += [ + { + "delete_id": delete_id, + **delete.asdict(), + } + ] + return HTTPStatus.OK, {"results": cast(JsonDict, response)} + + +class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, delete_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + delete_status = self._pagination_handler.get_delete_status(delete_id) + if delete_status is None: + raise NotFoundError("delete id '%s' not found" % delete_id) + + return HTTPStatus.OK, cast(JsonDict, delete_status.asdict()) + + class ListRoomRestServlet(RestServlet): """ List all rooms that are known to the homeserver. Results are returned @@ -60,40 +193,22 @@ class ListRoomRestServlet(RestServlet): self.admin_handler = hs.get_admin_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) - if order_by not in ( - RoomSortOrder.ALPHABETICAL.value, - RoomSortOrder.SIZE.value, - RoomSortOrder.NAME.value, - RoomSortOrder.CANONICAL_ALIAS.value, - RoomSortOrder.JOINED_MEMBERS.value, - RoomSortOrder.JOINED_LOCAL_MEMBERS.value, - RoomSortOrder.VERSION.value, - RoomSortOrder.CREATOR.value, - RoomSortOrder.ENCRYPTION.value, - RoomSortOrder.FEDERATABLE.value, - RoomSortOrder.PUBLIC.value, - RoomSortOrder.JOIN_RULES.value, - RoomSortOrder.GUEST_ACCESS.value, - RoomSortOrder.HISTORY_VISIBILITY.value, - RoomSortOrder.STATE_EVENTS.value, - ): - raise SynapseError( - 400, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) + order_by = parse_string( + request, + "order_by", + default=RoomSortOrder.NAME.value, + allowed_values=[sort_order.value for sort_order in RoomSortOrder], + ) search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "search_term cannot be an empty string", errcode=Codes.INVALID_PARAM, ) @@ -101,7 +216,9 @@ class ListRoomRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) reverse_order = True if direction == "b" else False @@ -133,7 +250,7 @@ class ListRoomRestServlet(RestServlet): else: response["prev_batch"] = 0 - return 200, response + return HTTPStatus.OK, response class RoomRestServlet(RestServlet): @@ -157,10 +274,9 @@ class RoomRestServlet(RestServlet): TODO: Add on_POST to allow room creation without joining the room """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() @@ -178,7 +294,7 @@ class RoomRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -239,9 +355,22 @@ class RoomRestServlet(RestServlet): # Purge room if purge: - await pagination_handler.purge_room(room_id, force=force_purge) - - return 200, ret + try: + await pagination_handler.purge_room(room_id, force=force_purge) + except NotFoundError: + if block: + # We can block unknown rooms with this endpoint, in which case + # a failed purge is expected. + pass + else: + # But otherwise, we expect this purge to have succeeded. + raise + + # Cast safety: cast away the knowledge that this is a TypedDict. + # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 + # for some discussion on why this is necessary. Either way, + # `ret` is an opaque dictionary blob as far as the rest of the app cares. + return HTTPStatus.OK, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): @@ -249,10 +378,9 @@ class RoomMembersRestServlet(RestServlet): Get members list of a room. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -268,7 +396,7 @@ class RoomMembersRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret = {"members": members, "total": len(members)} - return 200, ret + return HTTPStatus.OK, ret class RoomStateRestServlet(RestServlet): @@ -276,10 +404,9 @@ class RoomStateRestServlet(RestServlet): Get full state within a room. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -288,8 +415,7 @@ class RoomStateRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) if not ret: @@ -298,28 +424,22 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events( - events.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_aggregations=False, - ) + room_state = await self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} - return 200, ret + return HTTPStatus.OK, ret class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") + PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, room_identifier: str @@ -335,8 +455,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): assert_params_in_dict(content, ["user_id"]) target_user = UserID.from_string(content["user_id"]) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -382,7 +505,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): ratelimit=False, ) - return 200, {"room_id": room_id} + return HTTPStatus.OK, {"room_id": room_id} class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): @@ -397,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): } """ - PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() @@ -423,7 +545,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): # Figure out which local users currently have power in the room, if any. room_state = await self.state_handler.get_current_state(room_id) if not room_state: - raise SynapseError(400, "Server not in room") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") create_event = room_state[(EventTypes.Create, "")] power_levels = room_state.get((EventTypes.PowerLevels, "")) @@ -437,7 +559,9 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_users.sort(key=lambda user: user_power[user]) if not admin_users: - raise SynapseError(400, "No local admin user in room") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "No local admin user in room" + ) admin_user_id = None @@ -454,7 +578,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): if not admin_user_id: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -465,7 +589,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_user_id = create_event.sender if not self.is_mine_id(admin_user_id): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -494,7 +618,8 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): except AuthError: # The admin user we found turned out not to have enough power. raise SynapseError( - 400, "No local admin user in room with power to update power levels." + HTTPStatus.BAD_REQUEST, + "No local admin user in room with power to update power levels.", ) # Now we check if the user we're granting admin rights to is already in @@ -508,7 +633,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): ) if is_joined: - return 200, {} + return HTTPStatus.OK, {} join_rules = room_state.get((EventTypes.JoinRules, "")) is_public = False @@ -516,7 +641,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC if is_public: - return 200, {} + return HTTPStatus.OK, {} await self.room_member_handler.update_membership( fake_requester, @@ -525,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): action=Membership.INVITE, ) - return 200, {} + return HTTPStatus.OK, {} class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): @@ -540,35 +665,32 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities """ - PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() async def on_DELETE( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) - return 200, {"deleted": deleted_count} + return HTTPStatus.OK, {"deleted": deleted_count} async def on_GET( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return 200, {"count": len(extremities), "results": extremities} + return HTTPStatus.OK, {"count": len(extremities), "results": extremities} class RoomEventContextServlet(RestServlet): @@ -583,6 +705,7 @@ class RoomEventContextServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -600,7 +723,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) else: event_filter = None @@ -614,7 +739,9 @@ class RoomEventContextServlet(RestServlet): ) if not results: - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError( + HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND + ) time_now = self.clock.time_msec() results["events_before"] = await self._event_serializer.serialize_events( @@ -627,10 +754,70 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_aggregations=False, + results["state"], time_now ) - return 200, results + return HTTPStatus.OK, results + + +class BlockRoomRestServlet(RestServlet): + """ + Manage blocking of rooms. + On PUT: Add or remove a room from blocking list. + On GET: Get blocking status of room and user who has blocked this room. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + blocked_by = await self._store.room_is_blocked_by(room_id) + # Test `not None` if `user_id` is an empty string + # if someone add manually an entry in database + if blocked_by is not None: + response = {"block": True, "user_id": blocked_by} + else: + response = {"block": False} + + return HTTPStatus.OK, response + + async def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + assert_params_in_dict(content, ["block"]) + block = content.get("block") + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean.", + Codes.BAD_JSON, + ) + + if block: + await self._store.block_room(room_id, requester.user.to_string()) + else: + await self._store.unblock_room(room_id) + + return HTTPStatus.OK, {"block": block} diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 19f84f33f2..15da9cd881 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes @@ -51,11 +52,11 @@ class SendServerNoticeServlet(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.server_notices_manager = hs.get_server_notices_manager() self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) + self.is_mine = hs.is_mine def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" @@ -82,11 +83,15 @@ class SendServerNoticeServlet(RestServlet): # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). if not self.server_notices_manager.is_enabled(): - raise SynapseError(400, "Server notices are not enabled on this server") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server" + ) target_user = UserID.from_string(body["user_id"]) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Server notices can only be sent to local users") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -99,7 +104,7 @@ class SendServerNoticeServlet(RestServlet): txn_id=txn_id, ) - return 200, {"event_id": event.event_id} + return HTTPStatus.OK, {"event_id": event.event_id} def on_PUT( self, request: SynapseRequest, txn_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 948de94ccd..7a6546372e 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError @@ -36,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet): PATTERNS = admin_patterns("/statistics/users/media$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -44,24 +44,21 @@ class UserMediaStatisticsRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) order_by = parse_string( - request, "order_by", default=UserSortOrder.USER_ID.value + request, + "order_by", + default=UserSortOrder.USER_ID.value, + allowed_values=( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ), ) - if order_by not in ( - UserSortOrder.MEDIA_LENGTH.value, - UserSortOrder.MEDIA_COUNT.value, - UserSortOrder.USER_ID.value, - UserSortOrder.DISPLAYNAME.value, - ): - raise SynapseError( - 400, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) start = parse_integer(request, "from", default=0) if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -69,7 +66,7 @@ class UserMediaStatisticsRestServlet(RestServlet): limit = parse_integer(request, "limit", default=100) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -77,7 +74,7 @@ class UserMediaStatisticsRestServlet(RestServlet): from_ts = parse_integer(request, "from_ts", default=0) if from_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -86,13 +83,13 @@ class UserMediaStatisticsRestServlet(RestServlet): if until_ts is not None: if until_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if until_ts <= from_ts: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be greater than from_ts.", errcode=Codes.INVALID_PARAM, ) @@ -100,7 +97,7 @@ class UserMediaStatisticsRestServlet(RestServlet): search_term = parse_string(request, "search_term") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter search_term cannot be an empty string.", errcode=Codes.INVALID_PARAM, ) @@ -108,7 +105,9 @@ class UserMediaStatisticsRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) users_media, total = await self.store.get_users_media_usage_paginate( @@ -118,4 +117,4 @@ class UserMediaStatisticsRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(users_media) - return 200, ret + return HTTPStatus.OK, ret diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py index 2bf1472967..5353dc3682 100644 --- a/synapse/rest/admin/username_available.py +++ b/synapse/rest/admin/username_available.py @@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/username_available") + PATTERNS = admin_patterns("/username_available$") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index d14fafbbc9..db678da4cf 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -79,14 +78,14 @@ class UsersRestServletV2(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -122,11 +121,11 @@ class UsersRestServletV2(RestServlet): if (start + limit) < total: ret["next_token"] = str(start + len(users)) - return 200, ret + return HTTPStatus.OK, ret class UserRestServletV2(RestServlet): - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. @@ -172,14 +171,14 @@ class UserRestServletV2(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) if not ret: raise NotFoundError("User not found") - return 200, ret + return HTTPStatus.OK, ret async def on_PUT( self, request: SynapseRequest, user_id: str @@ -191,7 +190,10 @@ class UserRestServletV2(RestServlet): body = parse_json_object_from_request(request) if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() @@ -210,7 +212,7 @@ class UserRestServletV2(RestServlet): user_type = body.get("user_type", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") set_admin_to = body.get("admin", False) if not isinstance(set_admin_to, bool): @@ -223,11 +225,13 @@ class UserRestServletV2(RestServlet): password = body.get("password", None) if password is not None: if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") deactivate = body.get("deactivated", False) if not isinstance(deactivate, bool): - raise SynapseError(400, "'deactivated' parameter is not of type boolean") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" + ) # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: @@ -282,7 +286,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." + ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -293,7 +299,9 @@ class UserRestServletV2(RestServlet): if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: - raise SynapseError(400, "You may not demote yourself.") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "You may not demote yourself." + ) await self.store.set_server_admin(target_user, set_admin_to) @@ -319,7 +327,8 @@ class UserRestServletV2(RestServlet): and self.auth_handler.can_change_password() ): raise SynapseError( - 400, "Must provide a password to re-activate an account." + HTTPStatus.BAD_REQUEST, + "Must provide a password to re-activate an account.", ) await self.deactivate_account_handler.activate_account( @@ -332,7 +341,7 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) assert user is not None - return 200, user + return HTTPStatus.OK, user else: # create user displayname = body.get("displayname", None) @@ -381,7 +390,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." + ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -402,7 +413,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = admin_patterns("/register") + PATTERNS = admin_patterns("/register$") NONCE_TIMEOUT = 60 def __init__(self, hs: "HomeServer"): @@ -429,51 +440,61 @@ class UserRegisterServlet(RestServlet): nonce = secrets.token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return 200, {"nonce": nonce} + return HTTPStatus.OK, {"nonce": nonce} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() if not self.hs.config.registration.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled" + ) body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "nonce must be specified", + errcode=Codes.BAD_JSON, + ) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError(400, "unrecognised nonce") + raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "username must be specified", + errcode=Codes.BAD_JSON, ) else: if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") username = body["username"].encode("utf-8") if b"\x00" in username: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") if "password" not in body: raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "password must be specified", + errcode=Codes.BAD_JSON, ) else: password = body["password"] if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_bytes = password.encode("utf-8") if b"\x00" in password_bytes: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -482,10 +503,12 @@ class UserRegisterServlet(RestServlet): displayname = body.get("displayname", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") if "mac" not in body: - raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON + ) got_mac = body["mac"] @@ -507,7 +530,7 @@ class UserRegisterServlet(RestServlet): want_mac = want_mac_builder.hexdigest() if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): - raise SynapseError(403, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.register import RegisterRestServlet @@ -524,7 +547,7 @@ class UserRegisterServlet(RestServlet): ) result = await register._create_registration_details(user_id, body) - return 200, result + return HTTPStatus.OK, result class WhoisRestServlet(RestServlet): @@ -537,9 +560,9 @@ class WhoisRestServlet(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -551,16 +574,16 @@ class WhoisRestServlet(RestServlet): if target_user != auth_user: await assert_user_is_admin(self.auth, auth_user) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only whois a local user") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) - return 200, ret + return HTTPStatus.OK, ret class DeactivateAccountRestServlet(RestServlet): - PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -575,7 +598,9 @@ class DeactivateAccountRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine(UserID.from_string(target_user_id)): - raise SynapseError(400, "Can only deactivate local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Can only deactivate local users" + ) if not await self.store.get_user_by_id(target_user_id): raise NotFoundError("User not found") @@ -597,14 +622,13 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - return 200, {"id_server_unbind_result": id_server_unbind_result} + return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result} class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @@ -620,7 +644,7 @@ class AccountValidityRenewServlet(RestServlet): if "user_id" not in body: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Missing property 'user_id' in the request body", ) @@ -631,7 +655,7 @@ class AccountValidityRenewServlet(RestServlet): ) res = {"expiration_ts": expiration_ts} - return 200, res + return HTTPStatus.OK, res class ResetPasswordRestServlet(RestServlet): @@ -648,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -678,7 +701,7 @@ class ResetPasswordRestServlet(RestServlet): await self._set_password_handler.set_password( target_user_id, new_password_hash, logout_devices, requester ) - return 200, {} + return HTTPStatus.OK, {} class SearchUsersRestServlet(RestServlet): @@ -692,12 +715,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, target_user_id: str @@ -712,16 +735,16 @@ class SearchUsersRestServlet(RestServlet): # To allow all users to get the users list # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") + # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) logger.info("term: %s ", term) ret = await self.store.search_users(term) - return 200, ret + return HTTPStatus.OK, ret class UserAdminServlet(RestServlet): @@ -753,9 +776,9 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -764,12 +787,15 @@ class UserAdminServlet(RestServlet): target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) is_admin = await self.store.is_server_admin(target_user) - return 200, {"admin": is_admin} + return HTTPStatus.OK, {"admin": is_admin} async def on_PUT( self, request: SynapseRequest, user_id: str @@ -784,17 +810,20 @@ class UserAdminServlet(RestServlet): assert_params_in_dict(body, ["admin"]) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) set_admin_to = bool(body["admin"]) if target_user == auth_user and not set_admin_to: - raise SynapseError(400, "You may not demote yourself.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.") await self.store.set_server_admin(target_user, set_admin_to) - return 200, {} + return HTTPStatus.OK, {} class UserMembershipRestServlet(RestServlet): @@ -802,7 +831,7 @@ class UserMembershipRestServlet(RestServlet): Get room list of an user. """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -816,7 +845,7 @@ class UserMembershipRestServlet(RestServlet): room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} - return 200, ret + return HTTPStatus.OK, ret class PushersRestServlet(RestServlet): @@ -845,7 +874,7 @@ class PushersRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -854,7 +883,10 @@ class PushersRestServlet(RestServlet): filtered_pushers = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} + return HTTPStatus.OK, { + "pushers": filtered_pushers, + "total": len(filtered_pushers), + } class UserTokenRestServlet(RestServlet): @@ -874,10 +906,10 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str @@ -886,30 +918,36 @@ class UserTokenRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be logged in as") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" + ) body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") if valid_until_ms and not isinstance(valid_until_ms, int): - raise SynapseError(400, "'valid_until_ms' parameter must be an int") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" + ) if auth_user.to_string() == user_id: - raise SynapseError(400, "Cannot use admin API to login as self") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self" + ) - token = await self.auth_handler.get_access_token_for_user_id( + token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), device_id=None, valid_until_ms=valid_until_ms, puppets_user_id=user_id, ) - return 200, {"access_token": token} + return HTTPStatus.OK, {"access_token": token} class ShadowBanRestServlet(RestServlet): - """An admin API for shadow-banning a user. + """An admin API for controlling whether a user is shadow-banned. A shadow-banned users receives successful responses to their client-server API requests, but the events are not propagated into rooms. @@ -917,33 +955,57 @@ class ShadowBanRestServlet(RestServlet): Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. - Example: + Example of shadow-banning a user: POST /_synapse/admin/v1/users/@test:example.com/shadow_ban {} 200 OK {} + + Example of removing a user from being shadow-banned: + + DELETE /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be shadow-banned") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) await self.store.set_shadow_banned(UserID.from_string(user_id), True) - return 200, {} + return HTTPStatus.OK, {} + + async def on_DELETE( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) + + await self.store.set_shadow_banned(UserID.from_string(user_id), False) + + return HTTPStatus.OK, {} class RateLimitRestServlet(RestServlet): @@ -962,20 +1024,20 @@ class RateLimitRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Can only look up local users") + if not self.is_mine_id(user_id): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -996,15 +1058,17 @@ class RateLimitRestServlet(RestServlet): else: ret = {} - return 200, ret + return HTTPStatus.OK, ret async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1016,14 +1080,14 @@ class RateLimitRestServlet(RestServlet): if not isinstance(messages_per_second, int) or messages_per_second < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) if not isinstance(burst_count, int) or burst_count < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), errcode=Codes.INVALID_PARAM, ) @@ -1039,19 +1103,21 @@ class RateLimitRestServlet(RestServlet): "burst_count": ratelimit.burst_count, } - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") await self.store.delete_ratelimit_for_user(user_id) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index a0971ce994..b4cb90cb76 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) def client_patterns( path_regex: str, - releases: Iterable[int] = (0,), + releases: Iterable[str] = ("r0", "v3"), unstable: bool = True, v1: bool = False, ) -> Iterable[Pattern]: @@ -52,7 +52,7 @@ def client_patterns( v1_prefix = CLIENT_API_PREFIX + "/api/v1" patterns.append(re.compile("^" + v1_prefix + path_regex)) for release in releases: - new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) + new_prefix = CLIENT_API_PREFIX + f"/{release}" patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8566dc5cb5..ad6fd6492b 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api import errors +from synapse.api.errors import NotFoundError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -24,10 +25,9 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.types import JsonDict -from ._base import client_patterns, interactive_auth_handler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet): device = await self.device_handler.get_device( requester.user.to_string(), device_id ) + if device is None: + raise NotFoundError("No device found") return 200, device @interactive_auth_handler diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 7281b2ee29..730c18f08f 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet): } """ - PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",)) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d49a647b03..f9994658c4 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -14,7 +14,17 @@ import logging import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Union, +) from typing_extensions import TypedDict @@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, assert_params_in_dict, - parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -61,8 +70,9 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" - APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" - REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" + APPSERVICE_TYPE = "m.login.application_service" + APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -71,6 +81,7 @@ class LoginRestServlet(RestServlet): # JWT configuration variables. self.jwt_enabled = hs.config.jwt.jwt_enabled self.jwt_secret = hs.config.jwt.jwt_secret + self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences @@ -79,7 +90,9 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._refresh_tokens_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self.auth = hs.get_auth() @@ -143,23 +156,29 @@ class LoginRestServlet(RestServlet): flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE_UNSTABLE}) return 200, {"flows": flows} async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: login_submission = parse_json_object_from_request(request) - if self._msc2918_enabled: - # Check if this login should also issue a refresh token, as per - # MSC2918 - should_issue_refresh_token = parse_boolean( - request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False - ) - else: - should_issue_refresh_token = False + # Check to see if the client requested a refresh token. + client_requested_refresh_token = login_submission.get( + LoginRestServlet.REFRESH_TOKEN_PARAM, False + ) + if not isinstance(client_requested_refresh_token, bool): + raise SynapseError(400, "`refresh_token` should be true or false.") + + should_issue_refresh_token = ( + self._refresh_tokens_enabled and client_requested_refresh_token + ) try: - if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + if login_submission["type"] in ( + LoginRestServlet.APPSERVICE_TYPE, + LoginRestServlet.APPSERVICE_TYPE_UNSTABLE, + ): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): @@ -283,6 +302,7 @@ class LoginRestServlet(RestServlet): ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -298,10 +318,10 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. + auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: result: Dictionary of account information after successful login. @@ -334,6 +354,7 @@ class LoginRestServlet(RestServlet): initial_display_name, auth_provider_id=auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=auth_provider_session_id, ) result = LoginResponse( @@ -379,6 +400,7 @@ class LoginRestServlet(RestServlet): self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=res.auth_provider_session_id, ) async def _do_jwt_login( @@ -408,7 +430,7 @@ class LoginRestServlet(RestServlet): errcode=Codes.FORBIDDEN, ) - user = payload.get("sub", None) + user = payload.get(self.jwt_subject_claim, None) if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) @@ -440,14 +462,15 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = client_patterns( - "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True - ) + PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) + self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -457,27 +480,40 @@ class RefreshTokenServlet(RestServlet): if not isinstance(token, str): raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - valid_until_ms = self._clock.time_msec() + self.access_token_lifetime - access_token, refresh_token = await self._auth_handler.refresh_token( - token, valid_until_ms - ) - expires_in_ms = valid_until_ms - self._clock.time_msec() - return ( - 200, - { - "access_token": access_token, - "refresh_token": refresh_token, - "expires_in_ms": expires_in_ms, - }, + now = self._clock.time_msec() + access_valid_until_ms = None + if self.refreshable_access_token_lifetime is not None: + access_valid_until_ms = now + self.refreshable_access_token_lifetime + refresh_valid_until_ms = None + if self.refresh_token_lifetime is not None: + refresh_valid_until_ms = now + self.refresh_token_lifetime + + ( + access_token, + refresh_token, + actual_access_token_expiry, + ) = await self._auth_handler.refresh_token( + token, access_valid_until_ms, refresh_valid_until_ms ) + response: Dict[str, Union[str, int]] = { + "access_token": access_token, + "refresh_token": refresh_token, + } + + # expires_in_ms is only present if the token expires + if actual_access_token_expiry is not None: + response["expires_in_ms"] = actual_access_token_expiry - now + + return 200, response + class SsoRedirectServlet(RestServlet): PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ re.compile( "^" + CLIENT_API_PREFIX - + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" + + "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" ) ] @@ -556,7 +592,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.registration.access_token_lifetime is not None: + if hs.config.registration.refreshable_access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index d1d8a984c6..b12a332776 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ReceiptTypes from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" + user_id, ReceiptTypes.READ ) notif_event_ids = [pa["event_id"] for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 43c04fac6f..f51be511d1 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet): await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) + read_event_id = body.get(ReceiptTypes.READ, None) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): @@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet): if read_event_id: await self.receipts_handler.received_client_receipt( room_id, - "m.read", + ReceiptTypes.READ, user_id=requester.user.to_string(), event_id=read_event_id, hidden=hidden, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 9770413c61..b24ad2d1be 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -13,10 +13,12 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError +from synapse.http import get_request_user_agent from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -24,6 +26,8 @@ from synapse.types import JsonDict from ._base import client_patterns +pattern = re.compile(r"(?:Element|SchildiChat)/1\.[012]\.") + if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,10 +53,16 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if receipt_type != "m.read": + if receipt_type != ReceiptTypes.READ: raise SynapseError(400, "Receipt type must be 'm.read'") - body = parse_json_object_from_request(request, allow_empty_body=True) + # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. + user_agent = get_request_user_agent(request) + allow_empty_body = False + if "Android" in user_agent: + if pattern.match(user_agent) or "Riot" in user_agent: + allow_empty_body = True + body = parse_json_object_from_request(request, allow_empty_body) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index bf3cb34146..8b56c76aed 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, - parse_boolean, parse_json_object_from_request, parse_string, ) @@ -420,7 +419,9 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.registration.enable_registration - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._refresh_tokens_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -444,14 +445,15 @@ class RegisterRestServlet(RestServlet): f"Do not understand membership kind: {kind}", ) - if self._msc2918_enabled: - # Check if this registration should also issue a refresh token, as - # per MSC2918 - should_issue_refresh_token = parse_boolean( - request, name="org.matrix.msc2918.refresh_token", default=False - ) - else: - should_issue_refresh_token = False + # Check if the clients wishes for this registration to issue a refresh + # token. + client_requested_refresh_tokens = body.get("refresh_token", False) + if not isinstance(client_requested_refresh_tokens, bool): + raise SynapseError(400, "`refresh_token` should be true or false.") + + should_issue_refresh_token = ( + self._refresh_tokens_enabled and client_requested_refresh_tokens + ) # Pull out the provided username and do basic sanity checks early since # the auth layer will store these in sessions. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 58f6699073..ffa37ef06c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, limit=limit, @@ -224,18 +225,14 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_aggregations to False when retrieving the original - # event because we want the content before relations were applied to - # it. + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. original_event = await self._event_serializer.serialize_event( event, now, bundle_aggregations=False ) - # Similarly, we don't allow relations to be applied to relations, so we - # return the original relations without any aggregations on top of them - # here. - serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=False - ) + # The relations returned for the requested event do include their + # bundled aggregations. + serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() return_value["chunk"] = serialized_events @@ -298,7 +295,9 @@ class RelationAggregationPaginationServlet(RestServlet): raise SynapseError(404, "Unknown parent event.") if relation_type not in (RelationTypes.ANNOTATION, None): - raise SynapseError(400, "Relation type must be 'annotation'") + raise SynapseError( + 400, f"Relation type must be '{RelationTypes.ANNOTATION}'" + ) limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") @@ -319,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, + room_id=room_id, event_type=event_type, limit=limit, from_token=from_token, @@ -385,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. - await self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -404,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, aggregation_key=key, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 6a876cfa2f..60719331b6 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): state_key: str, txn_id: Optional[str] = None, ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if txn_id: set_tag("txn_id", txn_id) @@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -567,7 +568,9 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -688,7 +692,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + event_filter: Optional[Filter] = Filter( + self._hs, json_decoder.decode(filter_json) + ) else: event_filter = None @@ -710,10 +716,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_aggregations=False, + results["state"], time_now ) return 200, results @@ -1064,6 +1067,62 @@ def register_txn_path( ) +class TimestampLookupRestServlet(RestServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for cases like jump to date so you can start paginating messages from + a given date in the archive. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction> + { + "event_id": ... + } + """ + + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3030" + "/rooms/(?P<room_id>[^/]*)/timestamp_to_event$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await self._auth.check_user_in_room(room_id, requester.user.to_string()) + + timestamp = parse_integer(request, "ts", required=True) + direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + + ( + event_id, + origin_server_ts, + ) = await self.timestamp_lookup_handler.get_event_for_timestamp( + requester, room_id, timestamp, direction + ) + + return 200, { + "event_id": event_id, + "origin_server_ts": origin_server_ts, + } + + class RoomSpaceSummaryRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1134,7 +1193,7 @@ class RoomSpaceSummaryRestServlet(RestServlet): class RoomHierarchyRestServlet(RestServlet): PATTERNS = ( re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" + "^/_matrix/client/(v1|unstable/org.matrix.msc2946)" "/rooms/(?P<room_id>[^/]*)/hierarchy$" ), ) @@ -1162,7 +1221,7 @@ class RoomHierarchyRestServlet(RestServlet): ) return 200, await self._room_summary_handler.get_room_hierarchy( - requester.user.to_string(), + requester, room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), max_depth=max_depth, @@ -1233,6 +1292,8 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) + if hs.config.experimental.msc3030_enabled: + TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if not is_worker: diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 46f033eee2..e4c9451ae0 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -112,7 +112,7 @@ class RoomBatchSendEventRestServlet(RestServlet): # and have the batch connected. if batch_id_from_query: corresponding_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id( + await self.store.get_insertion_event_id_by_batch_id( room_id, batch_id_from_query ) ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 913216a7c4..dd90ffa123 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -29,7 +29,7 @@ from typing import ( from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.events.utils import ( @@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet): request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id is None: - filter_collection = DEFAULT_FILTER_COLLECTION + filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: filter_object = json_decoder.decode(filter_id) @@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet): except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) - filter_collection = FilterCollection(filter_object) + filter_collection = FilterCollection(self.hs, filter_object) else: try: filter_collection = await self.filtering.get_user_filter( @@ -293,6 +293,9 @@ class SyncRestServlet(RestServlet): response[ "org.matrix.msc2732.device_unused_fallback_key_types" ] = sync_result.device_unused_fallback_key_types + response[ + "device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types if joined: response["rooms"][Membership.JOIN] = joined @@ -520,9 +523,9 @@ class SyncRestServlet(RestServlet): return self._event_serializer.serialize_events( events, time_now=time_now, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_aggregations=False, + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. + bundle_aggregations=room.timeline.limited, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 8d888f4565..2290c57c12 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -93,6 +93,10 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving hidden read receipts as per MSC2285 "org.matrix.msc2285": self.config.experimental.msc2285_enabled, + # Adds support for importing historical messages as per MSC2716 + "org.matrix.msc2716": self.config.experimental.msc2716_enabled, + # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030 + "org.matrix.msc3030": self.config.experimental.msc3030_enabled, }, }, ) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 12b3ae120c..b9bfbea21b 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from canonicaljson import encode_canonical_json from signedjson.sign import sign_json @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> int: + def render_GET(self, request: Request) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 014fa893d6..9b40fd8a6c 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name logger = logging.getLogger(__name__) @@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ try: # The type on postpath seems incorrect in Twisted 21.2.0. postpath: List[bytes] = request.postpath # type: ignore @@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: server_name = server_name_bytes.decode("utf-8") media_id = media_id_bytes.decode("utf8") + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + file_name = None if len(postpath) > 2: try: diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index bec77088ee..1f6441c412 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -16,7 +16,8 @@ import functools import os import re -from typing import Any, Callable, List, TypeVar, cast +import string +from typing import Any, Callable, List, TypeVar, Union, cast NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") @@ -37,6 +38,113 @@ def _wrap_in_base_path(func: F) -> F: return cast(F, _wrapped) +GetPathMethod = TypeVar( + "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] +) + + +def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: + """Wraps a path-returning method to check that the returned path(s) do not escape + the media store directory. + + The path-returning method may return either a single path, or a list of paths. + + The check is not expected to ever fail, unless `func` is missing a call to + `_validate_path_component`, or `_validate_path_component` is buggy. + + Args: + relative: A boolean indicating whether the wrapped method returns paths relative + to the media store directory. + + Returns: + A method which will wrap a path-returning method, adding a check to ensure that + the returned path(s) lie within the media store directory. The check will raise + a `ValueError` if it fails. + """ + + def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # Construct the path that will ultimately be used. + # We cannot guess whether `path` is relative to the media store + # directory, since the media store directory may itself be a relative + # path. + if relative: + path = os.path.join(self.base_path, path) + normalized_path = os.path.normpath(path) + + # Now that `normpath` has eliminated `../`s and `./`s from the path, + # `os.path.commonpath` can be used to check whether it lies within the + # media store directory. + if ( + os.path.commonpath([normalized_path, self.normalized_base_path]) + != self.normalized_base_path + ): + # The path resolves to outside the media store directory, + # or `self.base_path` is `.`, which is an unlikely configuration. + raise ValueError(f"Invalid media store path: {path!r}") + + # Note that `os.path.normpath`/`abspath` has a subtle caveat: + # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a + # different path if `a/b/c` is a symlink. That is, the check above is + # not perfect and may allow a certain restricted subset of untrustworthy + # paths through. Since the check above is secondary to the main + # `_validate_path_component` checks, it's less important for it to be + # perfect. + # + # As an alternative, `os.path.realpath` will resolve symlinks, but + # proves problematic if there are symlinks inside the media store. + # eg. if `url_store/` is symlinked to elsewhere, its canonical path + # won't match that of the main media store directory. + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + return _wrap_with_jail_check_inner + + +ALLOWED_CHARACTERS = set( + string.ascii_letters + + string.digits + + "_-" + + ".[]:" # Domain names, IPv6 addresses and ports in server names +) +FORBIDDEN_NAMES = { + "", + os.path.curdir, # "." for the current platform + os.path.pardir, # ".." for the current platform +} + + +def _validate_path_component(name: str) -> str: + """Checks that the given string can be safely used as a path component + + Args: + name: The path component to check. + + Returns: + The path component if valid. + + Raises: + ValueError: If `name` cannot be safely used as a path component. + """ + if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: + raise ValueError(f"Invalid path component: {name!r}") + + return name + + class MediaFilePaths: """Describes where files are stored on disk. @@ -47,23 +155,45 @@ class MediaFilePaths: def __init__(self, primary_base_path: str): self.base_path = primary_base_path - + self.normalized_base_path = os.path.normpath(self.base_path) + + # Refuse to initialize if paths cannot be validated correctly for the current + # platform. + assert os.path.sep not in ALLOWED_CHARACTERS + assert os.path.altsep not in ALLOWED_CHARACTERS + # On Windows, paths have all sorts of weirdness which `_validate_path_component` + # does not consider. In any case, the remote media store can't work correctly + # for certain homeservers there, since ":"s aren't allowed in paths. + assert os.name == "posix" + + @_wrap_with_jail_check(relative=True) def local_media_filepath_rel(self, media_id: str) -> str: - return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) + return os.path.join( + "local_content", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + @_wrap_with_jail_check(relative=True) def local_media_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), ) local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + @_wrap_with_jail_check(relative=False) def local_media_thumbnail_dir(self, media_id: str) -> str: """ Retrieve the local store path of thumbnails of a given media_id @@ -76,18 +206,24 @@ class MediaFilePaths: return os.path.join( self.base_path, "local_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ) + @_wrap_with_jail_check(relative=True) def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( - "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] + "remote_content", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), ) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel( self, server_name: str, @@ -101,11 +237,11 @@ class MediaFilePaths: file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], - file_name, + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), ) remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) @@ -113,6 +249,7 @@ class MediaFilePaths: # Legacy path that was used to store thumbnails previously. # Should be removed after some time, when most of the thumbnails are stored # using the new path. + @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel_legacy( self, server_name: str, file_id: str, width: int, height: int, content_type: str ) -> str: @@ -120,43 +257,67 @@ class MediaFilePaths: file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], - file_name, + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), ) + @_wrap_with_jail_check(relative=False) def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), ) + @_wrap_with_jail_check(relative=True) def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join("url_cache", media_id[:10], media_id[11:]) + return os.path.join( + "url_cache", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) else: - return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:]) + return os.path.join( + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + @_wrap_with_jail_check(relative=False) def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): - return [os.path.join(self.base_path, "url_cache", media_id[:10])] + return [ + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[:10]) + ) + ] else: return [ - os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]), - os.path.join(self.base_path, "url_cache", media_id[0:2]), + os.path.join( + self.base_path, + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[0:2]) + ), ] + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -168,37 +329,46 @@ class MediaFilePaths: if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - "url_cache_thumbnails", media_id[:10], media_id[11:], file_name + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + _validate_path_component(file_name), ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], - file_name, + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), ) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:]) + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ) url_cache_thumbnail_directory = _wrap_in_base_path( url_cache_thumbnail_directory_rel ) + @_wrap_with_jail_check(relative=False) def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form <DATE><RANDOM_STRING> @@ -206,21 +376,35 @@ class MediaFilePaths: if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), ), - os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]), ] else: return [ os.path.join( self.base_path, "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ), os.path.join( - self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4] + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), ), - os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]), ] diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 2a59552c20..cce1527ed9 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, List, Optional import attr +from synapse.rest.media.v1.preview_html import parse_html_description from synapse.types import JsonDict from synapse.util import json_decoder @@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> if video_urls: open_graph_response["og:video"] = video_urls[0] - from synapse.rest.media.v1.preview_url_resource import _calc_description - - description = _calc_description(tree) + description = parse_html_description(tree) if description: open_graph_response["og:description"] = description diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py new file mode 100644 index 0000000000..30b067dd42 --- /dev/null +++ b/synapse/rest/media/v1/preview_html.py @@ -0,0 +1,397 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import itertools +import logging +import re +from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union +from urllib import parse as urlparse + +if TYPE_CHECKING: + from lxml import etree + +logger = logging.getLogger(__name__) + +_charset_match = re.compile( + br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I +) +_xml_encoding_match = re.compile( + br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I +) +_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) + + +def _normalise_encoding(encoding: str) -> Optional[str]: + """Use the Python codec's name as the normalised entry.""" + try: + return codecs.lookup(encoding).name + except LookupError: + return None + + +def _get_html_media_encodings( + body: bytes, content_type: Optional[str] +) -> Iterable[str]: + """ + Get potential encoding of the body based on the (presumably) HTML body or the content-type header. + + The precedence used for finding a character encoding is: + + 1. <meta> tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to utf-8. + 5. Fallback to windows-1252. + + This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # There's no point in returning an encoding more than once. + attempted_encodings: Set[str] = set() + + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Check if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding: + attempted_encodings.add(encoding) + yield encoding + + # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/> + + # Check if it has an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Check the HTTP Content-Type header for a character set. + if content_type: + content_match = _content_type_match.match(content_type) + if content_match: + encoding = _normalise_encoding(content_match.group(1)) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Finally, fallback to UTF-8, then windows-1252. + for fallback in ("utf-8", "cp1252"): + if fallback not in attempted_encodings: + yield fallback + + +def decode_body( + body: bytes, uri: str, content_type: Optional[str] = None +) -> Optional["etree.Element"]: + """ + This uses lxml to parse the HTML document. + + Args: + body: The HTML document, as bytes. + uri: The URI used to download the body. + content_type: The Content-Type header. + + Returns: + The parsed HTML body, or None if an error occurred during processed. + """ + # If there's no body, nothing useful is going to be found. + if not body: + return None + + # The idea here is that multiple encodings are tried until one works. + # Unfortunately the result is never used and then LXML will decode the string + # again with the found encoding. + for encoding in _get_html_media_encodings(body, content_type): + try: + body.decode(encoding) + except Exception: + pass + else: + break + else: + logger.warning("Unable to decode HTML body for %s", uri) + return None + + from lxml import etree + + # Create an HTML parser. + parser = etree.HTMLParser(recover=True, encoding=encoding) + + # Attempt to parse the body. Returns None if the body was successfully + # parsed, but no tree was found. + return etree.fromstring(body, parser) + + +def parse_html_to_open_graph( + tree: "etree.Element", media_uri: str +) -> Dict[str, Optional[str]]: + """ + Parse the HTML document into an Open Graph response. + + This uses lxml to search the HTML document for Open Graph data (or + synthesizes it from the document). + + Args: + tree: The parsed HTML document. + media_url: The URI used to download the body. + + Returns: + The Open Graph response as a dictionary. + """ + + # if we see any image URLs in the OG response, then spider them + # (although the client could choose to do this by asking for previews of those + # URLs to avoid DoSing the server) + + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og: Dict[str, Optional[str]] = {} + for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): + if "content" in tag.attrib: + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} + og[tag.attrib["property"]] = tag.attrib["content"] + + # TODO: grab article: meta tags too, e.g.: + + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + if "og:title" not in og: + # do some basic spidering of the HTML + title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") + if title and title[0].text is not None: + og["og:title"] = title[0].text.strip() + else: + og["og:title"] = None + + if "og:image" not in og: + # TODO: extract a favicon failing all else + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + ) + if meta_image: + og["og:image"] = rebase_url(meta_image[0], media_uri) + else: + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) + if not images: + images = tree.xpath("//img[@src]") + if images: + og["og:image"] = images[0].attrib["src"] + + if "og:description" not in og: + meta_description = tree.xpath( + "//*/meta" + "[translate(@name, 'DESCRIPTION', 'description')='description']" + "/@content" + ) + if meta_description: + og["og:description"] = meta_description[0] + else: + og["og:description"] = parse_html_description(tree) + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) + og["og:description"] = summarize_paragraphs([og["og:description"]]) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def parse_html_description(tree: "etree.Element") -> Optional[str]: + """ + Calculate a text description based on an HTML document. + + Grabs any text nodes which are inside the <body/> tag, unless they are within + an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or + if they are within a <script/> or <style/> tag. + + This is a very very very coarse approximation to a plain text render of the page. + + Args: + tree: The parsed HTML document. + + Returns: + The plain text description, or None if one cannot be generated. + """ + # We don't just use XPATH here as that is slow on some machines. + + from lxml import etree + + TAGS_TO_REMOVE = ( + "header", + "nav", + "aside", + "footer", + "script", + "noscript", + "style", + etree.Comment, + ) + + # Split all the text nodes into paragraphs (by splitting on new + # lines) + text_nodes = ( + re.sub(r"\s+", "\n", el).strip() + for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) + ) + return summarize_paragraphs(text_nodes) + + +def _iterate_over_text( + tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] +) -> Generator[str, None, None]: + """Iterate over the tree returning text nodes in a depth first fashion, + skipping text nodes inside certain tags. + """ + # This is basically a stack that we extend using itertools.chain. + # This will either consist of an element to iterate over *or* a string + # to be returned. + elements = iter([tree]) + while True: + el = next(elements, None) + if el is None: + return + + if isinstance(el, str): + yield el + elif el.tag not in tags_to_ignore: + # el.text is the text before the first child, so we can immediately + # return it if the text exists. + if el.text: + yield el.text + + # We add to the stack all the elements children, interspersed with + # each child's tail text (if it exists). The tail text of a node + # is text that comes *after* the node, so we always include it even + # if we ignore the child node. + elements = itertools.chain( + itertools.chain.from_iterable( # Basically a flatmap + [child, child.tail] if child.tail else [child] + for child in el.iterchildren() + ), + elements, + ) + + +def rebase_url(url: str, base: str) -> str: + base_parts = list(urlparse.urlparse(base)) + url_parts = list(urlparse.urlparse(url)) + if not url_parts[0]: # fix up schema + url_parts[0] = base_parts[0] or "http" + if not url_parts[1]: # fix up hostname + url_parts[1] = base_parts[1] + if not url_parts[2].startswith("/"): + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + return urlparse.urlunparse(url_parts) + + +def summarize_paragraphs( + text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 +) -> Optional[str]: + """ + Try to get a summary respecting first paragraph and then word boundaries. + + Args: + text_nodes: The paragraphs to summarize. + min_size: The minimum number of words to include. + max_size: The maximum number of words to include. + + Returns: + A summary of the text nodes, or None if that was not possible. + """ + + # TODO: Respect sentences? + + description = "" + + # Keep adding paragraphs until we get to the MIN_SIZE. + for text_node in text_nodes: + if len(description) < min_size: + text_node = re.sub(r"[\t \r\n]+", " ", text_node) + description += text_node + "\n\n" + else: + break + + description = description.strip() + description = re.sub(r"[\t ]+", " ", description) + description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description) + + # If the concatenation of paragraphs to get above MIN_SIZE + # took us over MAX_SIZE, then we need to truncate mid paragraph + if len(description) > max_size: + new_desc = "" + + # This splits the paragraph into words, but keeping the + # (preceding) whitespace intact so we can easily concat + # words back together. + for match in re.finditer(r"\s*\S+", description): + word = match.group() + + # Keep adding words while the total length is less than + # MAX_SIZE. + if len(word) + len(new_desc) < max_size: + new_desc += word + else: + # At this point the next word *will* take us over + # MAX_SIZE, but we also want to ensure that its not + # a huge word. If it is add it anyway and we'll + # truncate later. + if len(new_desc) < min_size: + new_desc += word + break + + # Double check that we're not over the limit + if len(new_desc) > max_size: + new_desc = new_desc[:max_size] + + # We always add an ellipsis because at the very least + # we chopped mid paragraph. + description = new_desc.strip() + "…" + return description if description else None diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8ca97b5b18..a3829d943b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -12,18 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import codecs import datetime import errno import fnmatch -import itertools import logging import os import re import shutil import sys import traceback -from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Iterable, Optional, Tuple from urllib import parse as urlparse import attr @@ -45,7 +43,12 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.types import JsonDict +from synapse.rest.media.v1.preview_html import ( + decode_body, + parse_html_to_open_graph, + rebase_url, +) +from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -54,21 +57,11 @@ from synapse.util.stringutils import random_string from ._base import FileInfo if TYPE_CHECKING: - from lxml import etree - from synapse.rest.media.v1.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) -_charset_match = re.compile( - br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I -) -_xml_encoding_match = re.compile( - br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I -) -_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) - OG_TAG_NAME_MAXLEN = 50 OG_TAG_VALUE_MAXLEN = 1000 @@ -231,7 +224,7 @@ class PreviewUrlResource(DirectServeJsonResource): og = await make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url: str, user: str, ts: int) -> bytes: + async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: @@ -311,7 +304,7 @@ class PreviewUrlResource(DirectServeJsonResource): # If there was no oEmbed URL (or oEmbed parsing failed), attempt # to generate the Open Graph information from the HTML. if not oembed_url or not og: - og = _calc_og(tree, media_info.uri) + og = parse_html_to_open_graph(tree, media_info.uri) await self._precache_image_url(user, media_info, og) else: @@ -360,7 +353,7 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") - async def _download_url(self, url: str, user: str) -> MediaInfo: + async def _download_url(self, url: str, user: UserID) -> MediaInfo: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -450,7 +443,7 @@ class PreviewUrlResource(DirectServeJsonResource): ) async def _precache_image_url( - self, user: str, media_info: MediaInfo, og: JsonDict + self, user: UserID, media_info: MediaInfo, og: JsonDict ) -> None: """ Pre-cache the image (if one exists) for posterity @@ -468,7 +461,7 @@ class PreviewUrlResource(DirectServeJsonResource): # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. image_info = await self._download_url( - _rebase_url(og["og:image"], media_info.uri), user + rebase_url(og["og:image"], media_info.uri), user ) if _is_media(image_info.media_type): @@ -632,301 +625,6 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def _normalise_encoding(encoding: str) -> Optional[str]: - """Use the Python codec's name as the normalised entry.""" - try: - return codecs.lookup(encoding).name - except LookupError: - return None - - -def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]: - """ - Get potential encoding of the body based on the (presumably) HTML body or the content-type header. - - The precedence used for finding a character encoding is: - - 1. <meta> tag with a charset declared. - 2. The XML document's character encoding attribute. - 3. The Content-Type header. - 4. Fallback to utf-8. - 5. Fallback to windows-1252. - - This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. - - Args: - body: The HTML document, as bytes. - content_type: The Content-Type header. - - Returns: - The character encoding of the body, as a string. - """ - # There's no point in returning an encoding more than once. - attempted_encodings: Set[str] = set() - - # Limit searches to the first 1kb, since it ought to be at the top. - body_start = body[:1024] - - # Check if it has an encoding set in a meta tag. - match = _charset_match.search(body_start) - if match: - encoding = _normalise_encoding(match.group(1).decode("ascii")) - if encoding: - attempted_encodings.add(encoding) - yield encoding - - # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/> - - # Check if it has an XML document with an encoding. - match = _xml_encoding_match.match(body_start) - if match: - encoding = _normalise_encoding(match.group(1).decode("ascii")) - if encoding and encoding not in attempted_encodings: - attempted_encodings.add(encoding) - yield encoding - - # Check the HTTP Content-Type header for a character set. - if content_type: - content_match = _content_type_match.match(content_type) - if content_match: - encoding = _normalise_encoding(content_match.group(1)) - if encoding and encoding not in attempted_encodings: - attempted_encodings.add(encoding) - yield encoding - - # Finally, fallback to UTF-8, then windows-1252. - for fallback in ("utf-8", "cp1252"): - if fallback not in attempted_encodings: - yield fallback - - -def decode_body( - body: bytes, uri: str, content_type: Optional[str] = None -) -> Optional["etree.Element"]: - """ - This uses lxml to parse the HTML document. - - Args: - body: The HTML document, as bytes. - uri: The URI used to download the body. - content_type: The Content-Type header. - - Returns: - The parsed HTML body, or None if an error occurred during processed. - """ - # If there's no body, nothing useful is going to be found. - if not body: - return None - - # The idea here is that multiple encodings are tried until one works. - # Unfortunately the result is never used and then LXML will decode the string - # again with the found encoding. - for encoding in get_html_media_encodings(body, content_type): - try: - body.decode(encoding) - except Exception: - pass - else: - break - else: - logger.warning("Unable to decode HTML body for %s", uri) - return None - - from lxml import etree - - # Create an HTML parser. - parser = etree.HTMLParser(recover=True, encoding=encoding) - - # Attempt to parse the body. Returns None if the body was successfully - # parsed, but no tree was found. - return etree.fromstring(body, parser) - - -def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: - """ - Calculate metadata for an HTML document. - - This uses lxml to search the HTML document for Open Graph data. - - Args: - tree: The parsed HTML document. - media_url: The URI used to download the body. - - Returns: - The Open Graph response as a dictionary. - """ - - # if we see any image URLs in the OG response, then spider them - # (although the client could choose to do this by asking for previews of those - # URLs to avoid DoSing the server) - - # "og:type" : "video", - # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", - # "og:site_name" : "YouTube", - # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : "Fun stuff happening here", - # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", - # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", - # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - # "og:video:width" : "1280" - # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - - og: Dict[str, Optional[str]] = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if "content" in tag.attrib: - # if we've got more than 50 tags, someone is taking the piss - if len(og) >= 50: - logger.warning("Skipping OG for page with too many 'og:' tags") - return {} - og[tag.attrib["property"]] = tag.attrib["content"] - - # TODO: grab article: meta tags too, e.g.: - - # "article:publisher" : "https://www.facebook.com/thethudonline" /> - # "article:author" content="https://www.facebook.com/thethudonline" /> - # "article:tag" content="baby" /> - # "article:section" content="Breaking News" /> - # "article:published_time" content="2016-03-31T19:58:24+00:00" /> - # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - - if "og:title" not in og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - if title and title[0].text is not None: - og["og:title"] = title[0].text.strip() - else: - og["og:title"] = None - - if "og:image" not in og: - # TODO: extract a favicon failing all else - meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" - ) - if meta_image: - og["og:image"] = _rebase_url(meta_image[0], media_uri) - else: - # TODO: consider inlined CSS styles as well as width & height attribs - images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted( - images, - key=lambda i: ( - -1 * float(i.attrib["width"]) * float(i.attrib["height"]) - ), - ) - if not images: - images = tree.xpath("//img[@src]") - if images: - og["og:image"] = images[0].attrib["src"] - - if "og:description" not in og: - meta_description = tree.xpath( - "//*/meta" - "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content" - ) - if meta_description: - og["og:description"] = meta_description[0] - else: - og["og:description"] = _calc_description(tree) - elif og["og:description"]: - # This must be a non-empty string at this point. - assert isinstance(og["og:description"], str) - og["og:description"] = summarize_paragraphs([og["og:description"]]) - - # TODO: delete the url downloads to stop diskfilling, - # as we only ever cared about its OG - return og - - -def _calc_description(tree: "etree.Element") -> Optional[str]: - """ - Calculate a text description based on an HTML document. - - Grabs any text nodes which are inside the <body/> tag, unless they are within - an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or - if they are within a <script/> or <style/> tag. - - This is a very very very coarse approximation to a plain text render of the page. - - Args: - tree: The parsed HTML document. - - Returns: - The plain text description, or None if one cannot be generated. - """ - # We don't just use XPATH here as that is slow on some machines. - - from lxml import etree - - TAGS_TO_REMOVE = ( - "header", - "nav", - "aside", - "footer", - "script", - "noscript", - "style", - etree.Comment, - ) - - # Split all the text nodes into paragraphs (by splitting on new - # lines) - text_nodes = ( - re.sub(r"\s+", "\n", el).strip() - for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) - ) - return summarize_paragraphs(text_nodes) - - -def _iterate_over_text( - tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] -) -> Generator[str, None, None]: - """Iterate over the tree returning text nodes in a depth first fashion, - skipping text nodes inside certain tags. - """ - # This is basically a stack that we extend using itertools.chain. - # This will either consist of an element to iterate over *or* a string - # to be returned. - elements = iter([tree]) - while True: - el = next(elements, None) - if el is None: - return - - if isinstance(el, str): - yield el - elif el.tag not in tags_to_ignore: - # el.text is the text before the first child, so we can immediately - # return it if the text exists. - if el.text: - yield el.text - - # We add to the stack all the elements children, interspersed with - # each child's tail text (if it exists). The tail text of a node - # is text that comes *after* the node, so we always include it even - # if we ignore the child node. - elements = itertools.chain( - itertools.chain.from_iterable( # Basically a flatmap - [child, child.tail] if child.tail else [child] - for child in el.iterchildren() - ), - elements, - ) - - -def _rebase_url(url: str, base: str) -> str: - base_parts = list(urlparse.urlparse(base)) - url_parts = list(urlparse.urlparse(url)) - if not url_parts[0]: # fix up schema - url_parts[0] = base_parts[0] or "http" - if not url_parts[1]: # fix up hostname - url_parts[1] = base_parts[1] - if not url_parts[2].startswith("/"): - url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] - return urlparse.urlunparse(url_parts) - - def _is_media(content_type: str) -> bool: return content_type.lower().startswith("image/") @@ -940,68 +638,3 @@ def _is_html(content_type: str) -> bool: def _is_json(content_type: str) -> bool: return content_type.lower().startswith("application/json") - - -def summarize_paragraphs( - text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 -) -> Optional[str]: - """ - Try to get a summary respecting first paragraph and then word boundaries. - - Args: - text_nodes: The paragraphs to summarize. - min_size: The minimum number of words to include. - max_size: The maximum number of words to include. - - Returns: - A summary of the text nodes, or None if that was not possible. - """ - - # TODO: Respect sentences? - - description = "" - - # Keep adding paragraphs until we get to the MIN_SIZE. - for text_node in text_nodes: - if len(description) < min_size: - text_node = re.sub(r"[\t \r\n]+", " ", text_node) - description += text_node + "\n\n" - else: - break - - description = description.strip() - description = re.sub(r"[\t ]+", " ", description) - description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description) - - # If the concatenation of paragraphs to get above MIN_SIZE - # took us over MAX_SIZE, then we need to truncate mid paragraph - if len(description) > max_size: - new_desc = "" - - # This splits the paragraph into words, but keeping the - # (preceding) whitespace intact so we can easily concat - # words back together. - for match in re.finditer(r"\s*\S+", description): - word = match.group() - - # Keep adding words while the total length is less than - # MAX_SIZE. - if len(word) + len(new_desc) < max_size: - new_desc += word - else: - # At this point the next word *will* take us over - # MAX_SIZE, but we also want to ensure that its not - # a huge word. If it is add it anyway and we'll - # truncate later. - if len(new_desc) < min_size: - new_desc += word - break - - # Double check that we're not over the limit - if len(new_desc) > max_size: - new_desc = new_desc[:max_size] - - # We always add an ellipsis because at the very least - # we chopped mid paragraph. - description = new_desc.strip() + "…" - return description if description else None diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 46701a8b83..5e17664b5b 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -101,8 +101,8 @@ class Thumbnailer: fits within the given rectangle:: (w_in / h_in) = (w_out / h_out) - w_out = min(w_max, h_max * (w_in / h_in)) - h_out = min(h_max, w_max * (h_in / w_in)) + w_out = max(min(w_max, h_max * (w_in / h_in)), 1) + h_out = max(min(h_max, w_max * (h_in / w_in)), 1) Args: max_width: The largest possible width. @@ -110,9 +110,9 @@ class Thumbnailer: """ if max_width * self.height < max_height * self.width: - return max_width, (max_width * self.height) // self.width + return max_width, max((max_width * self.height) // self.width, 1) else: - return (max_height * self.width) // self.height, max_height + return max((max_height * self.width) // self.height, 1), max_height def _resize(self, width: int, height: int) -> Image.Image: # 1-bit or 8-bit color palette images need converting to RGB diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index edbf5ce5d0..04b035a1b1 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,8 +34,7 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self) -> Optional[JsonDict]: - # if we don't have a public_baseurl, we can't help much here. - if self._config.server.public_baseurl is None: + if not self._config.server.serve_client_wellknown: return None result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} |