From e8ae94a22367a81049582dfdb16c743a45ca4e8b Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 29 Nov 2021 23:19:45 +0100 Subject: Convert status codes to `HTTPStatus` in `synapse.rest.admin` (#11452) --- synapse/rest/admin/__init__.py | 19 +-- synapse/rest/admin/_base.py | 3 +- synapse/rest/admin/devices.py | 21 ++-- synapse/rest/admin/event_reports.py | 21 ++-- synapse/rest/admin/groups.py | 5 +- synapse/rest/admin/media.py | 53 ++++----- synapse/rest/admin/registration_tokens.py | 51 +++++--- synapse/rest/admin/rooms.py | 68 ++++++----- synapse/rest/admin/server_notice_servlet.py | 11 +- synapse/rest/admin/statistics.py | 21 ++-- synapse/rest/admin/users.py | 173 ++++++++++++++++++---------- 11 files changed, 275 insertions(+), 171 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ee4a5e481b..c51a029bf3 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -17,6 +17,7 @@ import logging import platform +from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple import synapse @@ -98,7 +99,7 @@ class VersionServlet(RestServlet): } def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - return 200, self.res + return HTTPStatus.OK, self.res class PurgeHistoryRestServlet(RestServlet): @@ -130,7 +131,7 @@ class PurgeHistoryRestServlet(RestServlet): event = await self.store.get_event(event_id) if event.room_id != room_id: - raise SynapseError(400, "Event is for wrong room.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.") # RoomStreamToken expects [int] not Optional[int] assert event.internal_metadata.stream_ordering is not None @@ -144,7 +145,9 @@ class PurgeHistoryRestServlet(RestServlet): ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "purge_up_to_ts must be an int", + errcode=Codes.BAD_JSON, ) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) @@ -160,7 +163,9 @@ class PurgeHistoryRestServlet(RestServlet): stream_ordering, ) raise SynapseError( - 404, "there is no event to be purged", errcode=Codes.NOT_FOUND + HTTPStatus.NOT_FOUND, + "there is no event to be purged", + errcode=Codes.NOT_FOUND, ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) @@ -173,7 +178,7 @@ class PurgeHistoryRestServlet(RestServlet): ) else: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "must specify purge_up_to_event_id or purge_up_to_ts", errcode=Codes.BAD_JSON, ) @@ -182,7 +187,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, token, delete_local_events=delete_local_events ) - return 200, {"purge_id": purge_id} + return HTTPStatus.OK, {"purge_id": purge_id} class PurgeHistoryStatusRestServlet(RestServlet): @@ -201,7 +206,7 @@ class PurgeHistoryStatusRestServlet(RestServlet): if purge_status is None: raise NotFoundError("purge id '%s' not found" % purge_id) - return 200, purge_status.asdict() + return HTTPStatus.OK, purge_status.asdict() ######################################################################################## diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index d9a2f6ca15..399b205aaf 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -13,6 +13,7 @@ # limitations under the License. 
import re +from http import HTTPStatus from typing import Iterable, Pattern from synapse.api.auth import Auth @@ -62,4 +63,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: """ is_admin = await auth.is_server_admin(user_id) if not is_admin: - raise AuthError(403, "You are not a server admin") + raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 80fbf32f17..2e5a6600d3 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError @@ -53,7 +54,7 @@ class DeviceRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -62,7 +63,7 @@ class DeviceRestServlet(RestServlet): device = await self.device_handler.get_device( target_user.to_string(), device_id ) - return 200, device + return HTTPStatus.OK, device async def on_DELETE( self, request: SynapseRequest, user_id: str, device_id: str @@ -71,14 +72,14 @@ class DeviceRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") await self.device_handler.delete_device(target_user.to_string(), device_id) - return 200, {} + return HTTPStatus.OK, {} async def on_PUT( self, request: SynapseRequest, user_id: str, device_id: str @@ -87,7 +88,7 @@ class DeviceRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -97,7 +98,7 @@ class DeviceRestServlet(RestServlet): await self.device_handler.update_device( target_user.to_string(), device_id, body ) - return 200, {} + return HTTPStatus.OK, {} class DevicesRestServlet(RestServlet): @@ -124,14 +125,14 @@ class DevicesRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") devices = await self.device_handler.get_devices_by_user(target_user.to_string()) - return 200, {"devices": devices, "total": len(devices)} + return HTTPStatus.OK, {"devices": devices, "total": len(devices)} class DeleteDevicesRestServlet(RestServlet): @@ -155,7 +156,7 @@ class DeleteDevicesRestServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await 
self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -167,4 +168,4 @@ class DeleteDevicesRestServlet(RestServlet): await self.device_handler.delete_devices( target_user.to_string(), body["devices"] ) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index bbfcaf723b..5ee8b11110 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -66,21 +67,23 @@ class EventReportsRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The start parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The limit parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) event_reports, total = await self.store.get_event_reports_paginate( @@ -90,7 +93,7 @@ class EventReportsRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(event_reports) - return 200, ret + return HTTPStatus.OK, ret class EventReportDetailRestServlet(RestServlet): @@ -127,13 +130,17 @@ class EventReportDetailRestServlet(RestServlet): try: resolved_report_id = int(report_id) except ValueError: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) if resolved_report_id < 0: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) ret = await self.store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") - return 200, ret + return HTTPStatus.OK, ret diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index 68a3ba3cb7..a27110388f 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError @@ -43,7 +44,7 @@ class DeleteGroupAdminRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine_id(group_id): - raise SynapseError(400, "Can only delete local groups") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups") await self.group_server.delete_group(group_id, requester.user.to_string()) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 30a687d234..9e23e2d8fc 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -14,6 +14,7 @@ # limitations under the License. 
import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError @@ -62,7 +63,7 @@ class QuarantineMediaInRoom(RestServlet): room_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByUser(RestServlet): @@ -89,7 +90,7 @@ class QuarantineMediaByUser(RestServlet): user_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByID(RestServlet): @@ -118,7 +119,7 @@ class QuarantineMediaByID(RestServlet): server_name, media_id, requester.user.to_string() ) - return 200, {} + return HTTPStatus.OK, {} class UnquarantineMediaByID(RestServlet): @@ -147,7 +148,7 @@ class UnquarantineMediaByID(RestServlet): # Remove from quarantine this media id await self.store.quarantine_media_by_id(server_name, media_id, None) - return 200, {} + return HTTPStatus.OK, {} class ProtectMediaByID(RestServlet): @@ -170,7 +171,7 @@ class ProtectMediaByID(RestServlet): # Protect this media id await self.store.mark_local_media_as_safe(media_id, safe=True) - return 200, {} + return HTTPStatus.OK, {} class UnprotectMediaByID(RestServlet): @@ -193,7 +194,7 @@ class UnprotectMediaByID(RestServlet): # Unprotect this media id await self.store.mark_local_media_as_safe(media_id, safe=False) - return 200, {} + return HTTPStatus.OK, {} class ListMediaInRoom(RestServlet): @@ -211,11 +212,11 @@ class ListMediaInRoom(RestServlet): requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: - raise AuthError(403, "You are not a server admin") + raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) - return 200, {"local": local_mxcs, "remote": remote_mxcs} + return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs} class PurgeMediaCacheRestServlet(RestServlet): @@ -233,13 +234,13 @@ class PurgeMediaCacheRestServlet(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. 
" + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, @@ -247,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet): ret = await self.media_repository.delete_old_remote_media(before_ts) - return 200, ret + return HTTPStatus.OK, ret class DeleteMediaByID(RestServlet): @@ -267,7 +268,7 @@ class DeleteMediaByID(RestServlet): await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") if await self.store.get_local_media(media_id) is None: raise NotFoundError("Unknown media") @@ -277,7 +278,7 @@ class DeleteMediaByID(RestServlet): deleted_media, total = await self.media_repository.delete_local_media_ids( [media_id] ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class DeleteMediaByDateSize(RestServlet): @@ -304,26 +305,26 @@ class DeleteMediaByDateSize(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. " + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, ) if size_gt < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter size_gt must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") logging.info( "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s" @@ -333,7 +334,7 @@ class DeleteMediaByDateSize(RestServlet): deleted_media, total = await self.media_repository.delete_old_local_media( before_ts, size_gt, keep_profiles ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class UserMediaRestServlet(RestServlet): @@ -369,7 +370,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -380,14 +381,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -425,7 +426,7 @@ class UserMediaRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(media) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -436,7 +437,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local 
users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -447,14 +448,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -492,7 +493,7 @@ class UserMediaRestServlet(RestServlet): ([row["media_id"] for row in media]) ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index aba48f6e7b..891b98c088 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -14,6 +14,7 @@ import logging import string +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -77,7 +78,7 @@ class ListRegistrationTokensRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) valid = parse_boolean(request, "valid") token_list = await self.store.get_registration_tokens(valid) - return 200, {"registration_tokens": token_list} + return HTTPStatus.OK, {"registration_tokens": token_list} class NewRegistrationTokenRestServlet(RestServlet): @@ -123,16 +124,20 @@ class NewRegistrationTokenRestServlet(RestServlet): if "token" in body: token = body["token"] if not isinstance(token, str): - raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "token must be a string", + Codes.INVALID_PARAM, + ) if not (0 < len(token) <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must not be empty and must not be longer than 64 characters", Codes.INVALID_PARAM, ) if not set(token).issubset(self.allowed_chars_set): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must consist only of characters matched by the regex [A-Za-z0-9-_]", Codes.INVALID_PARAM, ) @@ -142,11 +147,13 @@ class NewRegistrationTokenRestServlet(RestServlet): length = body.get("length", 16) if not isinstance(length, int): raise SynapseError( - 400, "length must be an integer", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "length must be an integer", + Codes.INVALID_PARAM, ) if not (0 < length <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "length must be greater than zero and not greater than 64", Codes.INVALID_PARAM, ) @@ -162,7 +169,7 @@ class NewRegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -170,11 +177,15 @@ class NewRegistrationTokenRestServlet(RestServlet): expiry_time = body.get("expiry_time", None) if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( 
- 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) created = await self.store.create_registration_token( @@ -182,7 +193,9 @@ class NewRegistrationTokenRestServlet(RestServlet): ) if not created: raise SynapseError( - 400, f"Token already exists: {token}", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + f"Token already exists: {token}", + Codes.INVALID_PARAM, ) resp = { @@ -192,7 +205,7 @@ class NewRegistrationTokenRestServlet(RestServlet): "completed": 0, "expiry_time": expiry_time, } - return 200, resp + return HTTPStatus.OK, resp class RegistrationTokenRestServlet(RestServlet): @@ -261,7 +274,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Update a registration token.""" @@ -277,7 +290,7 @@ class RegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -287,11 +300,15 @@ class RegistrationTokenRestServlet(RestServlet): expiry_time = body["expiry_time"] if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) new_attributes["expiry_time"] = expiry_time @@ -307,7 +324,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_DELETE( self, request: SynapseRequest, token: str @@ -316,6 +333,6 @@ class RegistrationTokenRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if await self.store.delete_registration_token(token): - return 200, {} + return HTTPStatus.OK, {} raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a89dda1ba5..6bbc5510f0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -102,7 +102,9 @@ class RoomRestV2Servlet(RestServlet): ) if not RoomID.is_valid(room_id): - raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) if not await self._store.get_room(room_id): raise NotFoundError("Unknown room id %s" % (room_id,)) @@ -118,7 +120,7 @@ class RoomRestV2Servlet(RestServlet): force_purge=force_purge, ) - return 200, {"delete_id": delete_id} + return HTTPStatus.OK, {"delete_id": delete_id} class DeleteRoomStatusByRoomIdRestServlet(RestServlet): @@ -137,7 +139,9 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): - raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + 
) delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) if delete_ids is None: @@ -153,7 +157,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): **delete.asdict(), } ] - return 200, {"results": cast(JsonDict, response)} + return HTTPStatus.OK, {"results": cast(JsonDict, response)} class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): @@ -175,7 +179,7 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): if delete_status is None: raise NotFoundError("delete id '%s' not found" % delete_id) - return 200, cast(JsonDict, delete_status.asdict()) + return HTTPStatus.OK, cast(JsonDict, delete_status.asdict()) class ListRoomRestServlet(RestServlet): @@ -217,7 +221,7 @@ class ListRoomRestServlet(RestServlet): RoomSortOrder.STATE_EVENTS.value, ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Unknown value for order_by: %s" % (order_by,), errcode=Codes.INVALID_PARAM, ) @@ -225,7 +229,7 @@ class ListRoomRestServlet(RestServlet): search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "search_term cannot be an empty string", errcode=Codes.INVALID_PARAM, ) @@ -233,7 +237,9 @@ class ListRoomRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) reverse_order = True if direction == "b" else False @@ -265,7 +271,7 @@ class ListRoomRestServlet(RestServlet): else: response["prev_batch"] = 0 - return 200, response + return HTTPStatus.OK, response class RoomRestServlet(RestServlet): @@ -310,7 +316,7 @@ class RoomRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -386,7 +392,7 @@ class RoomRestServlet(RestServlet): # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 # for some discussion on why this is necessary. Either way, # `ret` is an opaque dictionary blob as far as the rest of the app cares. 
- return 200, cast(JsonDict, ret) + return HTTPStatus.OK, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): @@ -413,7 +419,7 @@ class RoomMembersRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret = {"members": members, "total": len(members)} - return 200, ret + return HTTPStatus.OK, ret class RoomStateRestServlet(RestServlet): @@ -452,7 +458,7 @@ class RoomStateRestServlet(RestServlet): ) ret = {"state": room_state} - return 200, ret + return HTTPStatus.OK, ret class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): @@ -481,7 +487,10 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): target_user = UserID.from_string(content["user_id"]) if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -527,7 +536,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): ratelimit=False, ) - return 200, {"room_id": room_id} + return HTTPStatus.OK, {"room_id": room_id} class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): @@ -568,7 +577,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): # Figure out which local users currently have power in the room, if any. room_state = await self.state_handler.get_current_state(room_id) if not room_state: - raise SynapseError(400, "Server not in room") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") create_event = room_state[(EventTypes.Create, "")] power_levels = room_state.get((EventTypes.PowerLevels, "")) @@ -582,7 +591,9 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_users.sort(key=lambda user: user_power[user]) if not admin_users: - raise SynapseError(400, "No local admin user in room") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "No local admin user in room" + ) admin_user_id = None @@ -599,7 +610,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): if not admin_user_id: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -610,7 +621,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_user_id = create_event.sender if not self.is_mine_id(admin_user_id): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -639,7 +650,8 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): except AuthError: # The admin user we found turned out not to have enough power. raise SynapseError( - 400, "No local admin user in room with power to update power levels." 
+ HTTPStatus.BAD_REQUEST, + "No local admin user in room with power to update power levels.", ) # Now we check if the user we're granting admin rights to is already in @@ -653,7 +665,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): ) if is_joined: - return 200, {} + return HTTPStatus.OK, {} join_rules = room_state.get((EventTypes.JoinRules, "")) is_public = False @@ -661,7 +673,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC if is_public: - return 200, {} + return HTTPStatus.OK, {} await self.room_member_handler.update_membership( fake_requester, @@ -670,7 +682,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): action=Membership.INVITE, ) - return 200, {} + return HTTPStatus.OK, {} class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): @@ -702,7 +714,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): room_id, _ = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) - return 200, {"deleted": deleted_count} + return HTTPStatus.OK, {"deleted": deleted_count} async def on_GET( self, request: SynapseRequest, room_identifier: str @@ -713,7 +725,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return 200, {"count": len(extremities), "results": extremities} + return HTTPStatus.OK, {"count": len(extremities), "results": extremities} class RoomEventContextServlet(RestServlet): @@ -762,7 +774,9 @@ class RoomEventContextServlet(RestServlet): ) if not results: - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError( + HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND + ) time_now = self.clock.time_msec() results["events_before"] = await self._event_serializer.serialize_events( @@ -781,7 +795,7 @@ class RoomEventContextServlet(RestServlet): bundle_relations=False, ) - return 200, results + return HTTPStatus.OK, results class BlockRoomRestServlet(RestServlet): diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 19f84f33f2..b295fb078b 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes @@ -82,11 +83,15 @@ class SendServerNoticeServlet(RestServlet): # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). 
if not self.server_notices_manager.is_enabled(): - raise SynapseError(400, "Server notices are not enabled on this server") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server" + ) target_user = UserID.from_string(body["user_id"]) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Server notices can only be sent to local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -99,7 +104,7 @@ class SendServerNoticeServlet(RestServlet): txn_id=txn_id, ) - return 200, {"event_id": event.event_id} + return HTTPStatus.OK, {"event_id": event.event_id} def on_PUT( self, request: SynapseRequest, txn_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 948de94ccd..ca41fd45f2 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError @@ -53,7 +54,7 @@ class UserMediaStatisticsRestServlet(RestServlet): UserSortOrder.DISPLAYNAME.value, ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Unknown value for order_by: %s" % (order_by,), errcode=Codes.INVALID_PARAM, ) @@ -61,7 +62,7 @@ class UserMediaStatisticsRestServlet(RestServlet): start = parse_integer(request, "from", default=0) if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -69,7 +70,7 @@ class UserMediaStatisticsRestServlet(RestServlet): limit = parse_integer(request, "limit", default=100) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -77,7 +78,7 @@ class UserMediaStatisticsRestServlet(RestServlet): from_ts = parse_integer(request, "from_ts", default=0) if from_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -86,13 +87,13 @@ class UserMediaStatisticsRestServlet(RestServlet): if until_ts is not None: if until_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if until_ts <= from_ts: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be greater than from_ts.", errcode=Codes.INVALID_PARAM, ) @@ -100,7 +101,7 @@ class UserMediaStatisticsRestServlet(RestServlet): search_term = parse_string(request, "search_term") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter search_term cannot be an empty string.", errcode=Codes.INVALID_PARAM, ) @@ -108,7 +109,9 @@ class UserMediaStatisticsRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) users_media, total = await self.store.get_users_media_usage_paginate( @@ -118,4 +121,4 @@ class UserMediaStatisticsRestServlet(RestServlet): if (start + 
limit) < total: ret["next_token"] = start + len(users_media) - return 200, ret + return HTTPStatus.OK, ret diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index ccd9a2a175..2a60b602b1 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -79,14 +79,14 @@ class UsersRestServletV2(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -122,7 +122,7 @@ class UsersRestServletV2(RestServlet): if (start + limit) < total: ret["next_token"] = str(start + len(users)) - return 200, ret + return HTTPStatus.OK, ret class UserRestServletV2(RestServlet): @@ -172,14 +172,14 @@ class UserRestServletV2(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) if not ret: raise NotFoundError("User not found") - return 200, ret + return HTTPStatus.OK, ret async def on_PUT( self, request: SynapseRequest, user_id: str @@ -191,7 +191,10 @@ class UserRestServletV2(RestServlet): body = parse_json_object_from_request(request) if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() @@ -210,7 +213,7 @@ class UserRestServletV2(RestServlet): user_type = body.get("user_type", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") set_admin_to = body.get("admin", False) if not isinstance(set_admin_to, bool): @@ -223,11 +226,13 @@ class UserRestServletV2(RestServlet): password = body.get("password", None) if password is not None: if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") deactivate = body.get("deactivated", False) if not isinstance(deactivate, bool): - raise SynapseError(400, "'deactivated' parameter is not of type boolean") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" + ) # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: @@ -282,7 +287,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." + ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -293,7 +300,9 @@ class UserRestServletV2(RestServlet): if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: - raise SynapseError(400, "You may not demote yourself.") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "You may not demote yourself." 
+ ) await self.store.set_server_admin(target_user, set_admin_to) @@ -319,7 +328,8 @@ class UserRestServletV2(RestServlet): and self.auth_handler.can_change_password() ): raise SynapseError( - 400, "Must provide a password to re-activate an account." + HTTPStatus.BAD_REQUEST, + "Must provide a password to re-activate an account.", ) await self.deactivate_account_handler.activate_account( @@ -332,7 +342,7 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) assert user is not None - return 200, user + return HTTPStatus.OK, user else: # create user displayname = body.get("displayname", None) @@ -381,7 +391,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." + ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -429,51 +441,61 @@ class UserRegisterServlet(RestServlet): nonce = secrets.token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return 200, {"nonce": nonce} + return HTTPStatus.OK, {"nonce": nonce} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() if not self.hs.config.registration.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled" + ) body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "nonce must be specified", + errcode=Codes.BAD_JSON, + ) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError(400, "unrecognised nonce") + raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "username must be specified", + errcode=Codes.BAD_JSON, ) else: if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") username = body["username"].encode("utf-8") if b"\x00" in username: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") if "password" not in body: raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "password must be specified", + errcode=Codes.BAD_JSON, ) else: password = body["password"] if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_bytes = password.encode("utf-8") if b"\x00" in password_bytes: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -482,10 +504,12 @@ class UserRegisterServlet(RestServlet): displayname = body.get("displayname", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") if "mac" not in body: - 
raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON + ) got_mac = body["mac"] @@ -507,7 +531,7 @@ class UserRegisterServlet(RestServlet): want_mac = want_mac_builder.hexdigest() if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): - raise SynapseError(403, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.register import RegisterRestServlet @@ -524,7 +548,7 @@ class UserRegisterServlet(RestServlet): ) result = await register._create_registration_details(user_id, body) - return 200, result + return HTTPStatus.OK, result class WhoisRestServlet(RestServlet): @@ -552,11 +576,11 @@ class WhoisRestServlet(RestServlet): await assert_user_is_admin(self.auth, auth_user) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only whois a local user") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) - return 200, ret + return HTTPStatus.OK, ret class DeactivateAccountRestServlet(RestServlet): @@ -575,7 +599,9 @@ class DeactivateAccountRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine(UserID.from_string(target_user_id)): - raise SynapseError(400, "Can only deactivate local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Can only deactivate local users" + ) if not await self.store.get_user_by_id(target_user_id): raise NotFoundError("User not found") @@ -597,7 +623,7 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - return 200, {"id_server_unbind_result": id_server_unbind_result} + return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result} class AccountValidityRenewServlet(RestServlet): @@ -620,7 +646,7 @@ class AccountValidityRenewServlet(RestServlet): if "user_id" not in body: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Missing property 'user_id' in the request body", ) @@ -631,7 +657,7 @@ class AccountValidityRenewServlet(RestServlet): ) res = {"expiration_ts": expiration_ts} - return 200, res + return HTTPStatus.OK, res class ResetPasswordRestServlet(RestServlet): @@ -678,7 +704,7 @@ class ResetPasswordRestServlet(RestServlet): await self._set_password_handler.set_password( target_user_id, new_password_hash, logout_devices, requester ) - return 200, {} + return HTTPStatus.OK, {} class SearchUsersRestServlet(RestServlet): @@ -712,16 +738,16 @@ class SearchUsersRestServlet(RestServlet): # To allow all users to get the users list # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") + # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) logger.info("term: %s ", term) ret = await self.store.search_users(term) - return 200, ret + return HTTPStatus.OK, ret class UserAdminServlet(RestServlet): @@ -765,11 +791,14 @@ class UserAdminServlet(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + raise SynapseError( + 
HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) is_admin = await self.store.is_server_admin(target_user) - return 200, {"admin": is_admin} + return HTTPStatus.OK, {"admin": is_admin} async def on_PUT( self, request: SynapseRequest, user_id: str @@ -785,16 +814,19 @@ class UserAdminServlet(RestServlet): assert_params_in_dict(body, ["admin"]) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) set_admin_to = bool(body["admin"]) if target_user == auth_user and not set_admin_to: - raise SynapseError(400, "You may not demote yourself.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.") await self.store.set_server_admin(target_user, set_admin_to) - return 200, {} + return HTTPStatus.OK, {} class UserMembershipRestServlet(RestServlet): @@ -816,7 +848,7 @@ class UserMembershipRestServlet(RestServlet): room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} - return 200, ret + return HTTPStatus.OK, ret class PushersRestServlet(RestServlet): @@ -845,7 +877,7 @@ class PushersRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -854,7 +886,10 @@ class PushersRestServlet(RestServlet): filtered_pushers = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} + return HTTPStatus.OK, { + "pushers": filtered_pushers, + "total": len(filtered_pushers), + } class UserTokenRestServlet(RestServlet): @@ -887,16 +922,22 @@ class UserTokenRestServlet(RestServlet): auth_user = requester.user if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be logged in as") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" + ) body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") if valid_until_ms and not isinstance(valid_until_ms, int): - raise SynapseError(400, "'valid_until_ms' parameter must be an int") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" + ) if auth_user.to_string() == user_id: - raise SynapseError(400, "Cannot use admin API to login as self") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self" + ) token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), @@ -905,7 +946,7 @@ class UserTokenRestServlet(RestServlet): puppets_user_id=user_id, ) - return 200, {"access_token": token} + return HTTPStatus.OK, {"access_token": token} class ShadowBanRestServlet(RestServlet): @@ -947,11 +988,13 @@ class ShadowBanRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be shadow-banned") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) await self.store.set_shadow_banned(UserID.from_string(user_id), True) - return 200, {} + return HTTPStatus.OK, {} async def on_DELETE( self, request: SynapseRequest, 
user_id: str @@ -959,11 +1002,13 @@ class ShadowBanRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be shadow-banned") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) await self.store.set_shadow_banned(UserID.from_string(user_id), False) - return 200, {} + return HTTPStatus.OK, {} class RateLimitRestServlet(RestServlet): @@ -995,7 +1040,7 @@ class RateLimitRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1016,7 +1061,7 @@ class RateLimitRestServlet(RestServlet): else: ret = {} - return 200, ret + return HTTPStatus.OK, ret async def on_POST( self, request: SynapseRequest, user_id: str @@ -1024,7 +1069,9 @@ class RateLimitRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1036,14 +1083,14 @@ class RateLimitRestServlet(RestServlet): if not isinstance(messages_per_second, int) or messages_per_second < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) if not isinstance(burst_count, int) or burst_count < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), errcode=Codes.INVALID_PARAM, ) @@ -1059,7 +1106,7 @@ class RateLimitRestServlet(RestServlet): "burst_count": ratelimit.burst_count, } - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -1067,11 +1114,13 @@ class RateLimitRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") await self.store.delete_ratelimit_for_user(user_id) - return 200, {} + return HTTPStatus.OK, {} -- cgit 1.5.1 From 379f2650cf875f50c59524147ec0e33cfd5ef60c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 30 Nov 2021 11:33:33 -0500 Subject: Bundle relations of relations into the `/relations` result. (#11284) Per updates to MSC2675 which now states that bundled aggregations should be included from the `/relations` endpoint. 
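For context on what this changes for API consumers: each event returned in the `/relations` chunk now carries its own bundled aggregations under `unsigned` → `m.relations`, per MSC2675. The following is a minimal client-side sketch, not part of the patch — the helper name is invented, and the response shape is assumed to follow MSC2675:

```
# Illustrative sketch only (not part of this patch): collect the m.annotation
# aggregations that are now bundled onto each event returned from the
# /relations endpoint, per MSC2675. The helper name is hypothetical.
from typing import Any, Dict, List


def collect_bundled_annotations(
    relations_response: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Gather the m.annotation aggregation entries bundled onto each event
    in a /relations response's "chunk" list."""
    collected: List[Dict[str, Any]] = []
    for event in relations_response.get("chunk", []):
        bundled = event.get("unsigned", {}).get("m.relations", {})
        # Each aggregation entry looks like:
        #   {"type": "m.reaction", "key": "a", "count": 1}
        collected.extend(bundled.get("m.annotation", {}).get("chunk", []))
    return collected
```

Note that the `serialize_event` change below deliberately skips bundling for events that are themselves edits or annotations, so clients should not expect `m.relations` on those events.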
--- changelog.d/11284.feature | 1 + synapse/events/utils.py | 8 +++ synapse/rest/client/relations.py | 9 +-- tests/rest/client/test_relations.py | 118 ++++++++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 changelog.d/11284.feature (limited to 'synapse/rest') diff --git a/changelog.d/11284.feature b/changelog.d/11284.feature new file mode 100644 index 0000000000..cbaa5a988c --- /dev/null +++ b/changelog.d/11284.feature @@ -0,0 +1 @@ +When returning relation events from the `/relations` API, bundle any relations of those relations into the result, per updates to [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e5967c995e..05219a9dd0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -435,6 +435,14 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ + # Do not bundle relations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return + event_id = event.event_id # The bundled relations to include. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 45e9f1dd90..b1a3304849 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -230,12 +230,9 @@ class RelationPaginationServlet(RestServlet): original_event = await self._event_serializer.serialize_event( event, now, bundle_relations=False ) - # Similarly, we don't allow relations to be applied to relations, so we - # return the original relations without any aggregations on top of them - # here. - serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_relations=False - ) + # The relations returned for the requested event do include their + # bundled relations. + serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() return_value["chunk"] = serialized_events diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index eb10d43217..b494da5138 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -526,6 +526,74 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + def test_aggregation_get_event_for_annotation(self): + """Test that annotations do not get bundled relations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. 
+ channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self): + """Test that threads get bundled relations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEquals( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + def test_edit(self): """Test that a simple edit works.""" @@ -672,6 +740,56 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + def test_edit_edit(self): + """Test that an edit cannot be edited.""" + new_body = {"msgtype": "m.text", "body": "Initial edit"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": new_body, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + edit_event_id = channel.json_body["event_id"] + + # Edit the edit event. + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "foo", + "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, + }, + parent_id=edit_event_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Request the original event. + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + # The edit to the edit should be ignored. + self.assertEquals(channel.json_body["content"], new_body) + + # The relations information should not include the edit to the edit. 
+ relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. -- cgit 1.5.1 From a265fbd397ae72b2d3ea4c9310591ff1d0f3e05c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Dec 2021 07:25:58 -0500 Subject: Register the login redirect endpoint for v3. (#11451) As specified for Matrix v1.1. --- changelog.d/11451.bugfix | 1 + synapse/rest/client/login.py | 2 +- synapse/rest/client/room.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11451.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/11451.bugfix b/changelog.d/11451.bugfix new file mode 100644 index 0000000000..960714d0f9 --- /dev/null +++ b/changelog.d/11451.bugfix @@ -0,0 +1 @@ +Add support for the `/_matrix/client/v3/login/sso/redirect/{idpId}` API from Matrix v1.1. This endpoint was overlooked when support for v3 endpoints was added in v1.48.0rc1. diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 09f378f919..a66ee4fb3d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -513,7 +513,7 @@ class SsoRedirectServlet(RestServlet): re.compile( "^" + CLIENT_API_PREFIX - + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" + + "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" ) ] diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 73d0f7c950..99f303c88e 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1138,12 +1138,12 @@ class RoomSpaceSummaryRestServlet(RestServlet): class RoomHierarchyRestServlet(RestServlet): - PATTERNS = [ + PATTERNS = ( re.compile( "^/_matrix/client/(v1|unstable/org.matrix.msc2946)" "/rooms/(?P<room_id>[^/]*)/hierarchy$" ), - ] + ) def __init__(self, hs: "HomeServer"): super().__init__() -- cgit 1.5.1 From a6f1a3abecf8e8fd3e1bff439a06b853df18f194 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 2 Dec 2021 01:02:20 -0600 Subject: Add MSC3030 experimental client and federation API endpoints to get the closest event to a given timestamp (#9445) MSC3030: https://github.com/matrix-org/matrix-doc/pull/3030 Client API endpoint. This will also go and fetch from the federation API endpoint if it is unable to find an event locally, or if it found an extremity with a possibly closer event we don't know about. ``` GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction> { "event_id": ... "origin_server_ts": ... } ``` Federation API endpoint: ``` GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/<roomID>?ts=<timestamp>&dir=<direction> { "event_id": ... "origin_server_ts": ...
} ``` Co-authored-by: Erik Johnston --- changelog.d/9445.feature | 1 + synapse/config/experimental.py | 3 + synapse/federation/federation_client.py | 77 +++++++++ synapse/federation/federation_server.py | 43 +++++ synapse/federation/transport/client.py | 36 ++++ synapse/federation/transport/server/__init__.py | 12 +- synapse/federation/transport/server/federation.py | 41 +++++ synapse/handlers/federation.py | 61 +++---- synapse/handlers/room.py | 144 ++++++++++++++++ synapse/http/servlet.py | 29 ++++ synapse/rest/client/room.py | 58 +++++++ synapse/server.py | 5 + synapse/storage/databases/main/events_worker.py | 195 ++++++++++++++++++++++ 13 files changed, 674 insertions(+), 31 deletions(-) create mode 100644 changelog.d/9445.feature (limited to 'synapse/rest') diff --git a/changelog.d/9445.feature b/changelog.d/9445.feature new file mode 100644 index 0000000000..6d12eea71f --- /dev/null +++ b/changelog.d/9445.feature @@ -0,0 +1 @@ +Add [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) experimental client and federation API endpoints to get the closest event to a given timestamp. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 8b098ad48d..d78a15097c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -46,3 +46,6 @@ class ExperimentalConfig(Config): # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) + + # MSC3030 (Jump to date API endpoint) + self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index bc3f96c1fc..be1423da24 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1517,6 +1517,83 @@ class FederationClient(FederationBase): self._get_room_hierarchy_cache[(room_id, suggested_only)] = result return result + async def timestamp_to_event( + self, destination: str, room_id: str, timestamp: int, direction: str + ) -> "TimestampToEventResponse": + """ + Calls a remote federating server at `destination` asking for their + closest event to the given timestamp in the given direction. Also + validates the response to always return the expected keys or raises an + error. + + Args: + destination: Domain name of the remote homeserver + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. 
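Before the response-parsing code below, the lookup contract is worth pinning down: direction `f` finds the first event at or after the timestamp, `b` the last event at or before it, both inclusive. A self-contained model of that contract over an in-memory, time-sorted list (stdlib `bisect`; not Synapse code):

```python
import bisect

events = [(1000, "$a"), (2000, "$b"), (3000, "$c")]  # sorted by origin_server_ts
timestamps = [ts for ts, _ in events]

def closest_event(ts: int, direction: str):
    if direction == "f":
        # First event with origin_server_ts >= ts (inclusive).
        i = bisect.bisect_left(timestamps, ts)
        return events[i] if i < len(events) else None
    else:
        # Last event with origin_server_ts <= ts (inclusive).
        i = bisect.bisect_right(timestamps, ts) - 1
        return events[i] if i >= 0 else None

assert closest_event(1500, "f") == (2000, "$b")
assert closest_event(1500, "b") == (1000, "$a")
assert closest_event(2000, "b") == (2000, "$b")  # inclusive on exact match
```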
+ + Returns: + A parsed TimestampToEventResponse including the closest event_id + and origin_server_ts + + Raises: + Various exceptions when the request fails + InvalidResponseError when the response does not have the correct + keys or the values are the wrong types + """ + remote_response = await self.transport_layer.timestamp_to_event( + destination, room_id, timestamp, direction + ) + + if not isinstance(remote_response, dict): + raise InvalidResponseError( + "Response must be a JSON dictionary but received %r" % remote_response + ) + + try: + return TimestampToEventResponse.from_json_dict(remote_response) + except ValueError as e: + raise InvalidResponseError(str(e)) + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class TimestampToEventResponse: + """Typed response dictionary for the federation /timestamp_to_event endpoint""" + + event_id: str + origin_server_ts: int + + # the raw data, including the above keys + data: JsonDict + + @classmethod + def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse": + """Parsed response from the federation /timestamp_to_event endpoint + + Args: + d: JSON object response to be parsed + + Raises: + ValueError if d does not have the correct keys or they are the wrong types + """ + + event_id = d.get("event_id") + if not isinstance(event_id, str): + raise ValueError( + "Invalid response: 'event_id' must be a str but received %r" % event_id + ) + + origin_server_ts = d.get("origin_server_ts") + if not isinstance(origin_server_ts, int): + raise ValueError( + "Invalid response: 'origin_server_ts' must be an int but received %r" + % origin_server_ts + ) + + return cls(event_id, origin_server_ts, d) + @attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8fbc75aa65..cce85526e7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -110,6 +110,7 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() + self.storage = hs.get_storage() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -200,6 +201,48 @@ class FederationServer(FederationBase): return 200, res + async def on_timestamp_to_event_request( + self, origin: str, room_id: str, timestamp: int, direction: str + ) -> Tuple[int, Dict[str, Any]]: + """When we receive a federated `/timestamp_to_event` request, + handle all of the logic for validating and fetching the event. + + Args: + origin: The server we received the request from + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + Tuple indicating the response status code and dictionary response + body including `event_id`.
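The attrs model above validates before constructing. The same validate-then-construct pattern with only the stdlib (a frozen dataclass standing in for the attrs class; illustrative, not the Synapse type):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TimestampToEvent:
    event_id: str
    origin_server_ts: int

def parse_response(d: dict) -> TimestampToEvent:
    """Reject anything that is not {str event_id, int origin_server_ts}."""
    event_id = d.get("event_id")
    if not isinstance(event_id, str):
        raise ValueError(f"'event_id' must be a str, got {event_id!r}")
    ts = d.get("origin_server_ts")
    # NB: bool is a subclass of int, so a strict parser may want to reject it.
    if not isinstance(ts, int) or isinstance(ts, bool):
        raise ValueError(f"'origin_server_ts' must be an int, got {ts!r}")
    return TimestampToEvent(event_id, ts)

print(parse_response({"event_id": "$abc", "origin_server_ts": 1638316800000}))
```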
+ """ + with (await self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + # We only try to fetch data from the local database + event_id = await self.store.get_event_id_for_timestamp( + room_id, timestamp, direction + ) + if event_id: + event = await self.store.get_event( + event_id, allow_none=False, allow_rejected=False + ) + + return 200, { + "event_id": event_id, + "origin_server_ts": event.origin_server_ts, + } + + raise SynapseError( + 404, + "Unable to find event from %s in direction %s" % (timestamp, direction), + errcode=Codes.NOT_FOUND, + ) + async def on_incoming_transaction( self, origin: str, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index fe29bcfd4b..d1f4be641d 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -148,6 +148,42 @@ class TransportLayerClient: destination, path=path, args=args, try_trailing_slash_on_400=True ) + @log_function + async def timestamp_to_event( + self, destination: str, room_id: str, timestamp: int, direction: str + ) -> Union[JsonDict, List]: + """ + Calls a remote federating server at `destination` asking for their + closest event to the given timestamp in the given direction. + + Args: + destination: Domain name of the remote homeserver + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + Response dict received from the remote homeserver. + + Raises: + Various exceptions when the request fails + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, + "/org.matrix.msc3030/timestamp_to_event/%s", + room_id, + ) + + args = {"ts": [str(timestamp)], "dir": [direction]} + + remote_response = await self.client.get_json( + destination, path=path, args=args, try_trailing_slash_on_400=True + ) + + return remote_response + @log_function async def send_transaction( self, diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index c32539bf5a..abcb8728f5 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -22,7 +22,10 @@ from synapse.federation.transport.server._base import ( Authenticator, BaseFederationServlet, ) -from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES +from synapse.federation.transport.server.federation import ( + FEDERATION_SERVLET_CLASSES, + FederationTimestampLookupServlet, +) from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES from synapse.federation.transport.server.groups_server import ( GROUP_SERVER_SERVLET_CLASSES, @@ -324,6 +327,13 @@ def register_servlets( ) for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]: + # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled + if ( + servletclass == FederationTimestampLookupServlet + and not hs.config.experimental.msc3030_enabled + ): + continue + servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 66e915228c..77bfd88ad0 100644 --- a/synapse/federation/transport/server/federation.py +++ 
b/synapse/federation/transport/server/federation.py @@ -174,6 +174,46 @@ class FederationBackfillServlet(BaseFederationServerServlet): return await self.handler.on_backfill_request(origin, room_id, versions, limit) +class FederationTimestampLookupServlet(BaseFederationServerServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for other homeservers when they're unable to find an event locally. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/<roomID>?ts=<timestamp>&dir=<direction> + { + "event_id": ... + } + """ + + PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?" + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + timestamp = parse_integer_from_args(query, "ts", required=True) + direction = parse_string_from_args( + query, "dir", default="f", allowed_values=["f", "b"], required=True + ) + + return await self.handler.on_timestamp_to_event_request( + origin, room_id, timestamp, direction + ) + + class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P<query_type>[^/]*)" @@ -683,6 +723,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, + FederationTimestampLookupServlet, FederationQueryServlet, FederationMakeJoinServlet, FederationMakeLeaveServlet, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3112cc88b1..1ea837d082 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -68,6 +68,37 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: + """Get joined domains from state + + Args: + state: State map from type/state key to event. + + Returns: + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. + """ + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member and event.membership == Membership.JOIN + ] + + joined_domains: Dict[str, int] = {} + for u, d in joined_users: + try: + dom = get_domain_from_id(u) + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except Exception: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + + class FederationHandler: """Handles general incoming federation requests @@ -268,36 +299,6 @@ class FederationHandler: curr_state = await self.state_handler.get_current_state(room_id) - def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: - """Get joined domains from state - - Args: - state: State map from type/state key to event. - - Returns: - Returns a list of servers with the lowest depth of their joins. - Sorted by lowest depth first.
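`get_domains_from_state` is hoisted to module level here so the new timestamp handler can reuse it (the old nested copy is removed just below). Its ranking behaviour, sketched standalone with plain tuples instead of Synapse's state types:

```python
# Collapse joined users to their homeserver domains, keeping the lowest
# join depth per domain, then sort ascending so the longest-joined
# servers come first.
joined_users = [
    ("@alice:example.org", 5),
    ("@bob:matrix.org", 2),
    ("@carol:example.org", 9),
]

joined_domains: dict = {}
for user_id, depth in joined_users:
    domain = user_id.split(":", 1)[1]  # crude stand-in for get_domain_from_id
    prev = joined_domains.get(domain)
    joined_domains[domain] = depth if prev is None else min(prev, depth)

print(sorted(joined_domains.items(), key=lambda kv: kv[1]))
# -> [('matrix.org', 2), ('example.org', 5)]
```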
- """ - joined_users = [ - (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() - if e_type == EventTypes.Member and event.membership == Membership.JOIN - ] - - joined_domains: Dict[str, int] = {} - for u, d in joined_users: - try: - dom = get_domain_from_id(u) - old_d = joined_domains.get(dom) - if old_d: - joined_domains[dom] = min(d, old_d) - else: - joined_domains[dom] = d - except Exception: - pass - - return sorted(joined_domains.items(), key=lambda d: d[1]) - curr_domains = get_domains_from_state(curr_state) likely_domains = [ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 88053f9869..2bcdf32dcc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -46,6 +46,7 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, + HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.streams import EventSource @@ -1220,6 +1223,147 @@ class RoomContextHandler: return results +class TimestampLookupHandler: + def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.state_handler = hs.get_state_handler() + self.federation_client = hs.get_federation_client() + + async def get_event_for_timestamp( + self, + requester: Requester, + room_id: str, + timestamp: int, + direction: str, + ) -> Tuple[str, int]: + """Find the closest event to the given timestamp in the given direction. + If we can't find an event locally or the event we have locally is next to a gap, + it will ask other federated homeservers for an event. + + Args: + requester: The user making the request according to the access token + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + A tuple containing the `event_id` closest to the given timestamp in + the given direction and the `origin_server_ts`. + + Raises: + SynapseError if unable to find any event locally in the given direction + """ + + local_event_id = await self.store.get_event_id_for_timestamp( + room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s", + local_event_id, + timestamp, + ) + + # Check for gaps in the history where events could be hiding in between + # the timestamp given and the event we were able to find locally + is_event_next_to_backward_gap = False + is_event_next_to_forward_gap = False + if local_event_id: + local_event = await self.store.get_event( + local_event_id, allow_none=False, allow_rejected=False + ) + + if direction == "f": + # We only need to check for a backward gap if we're looking forwards + # to ensure there is nothing in between. 
+ is_event_next_to_backward_gap = ( + await self.store.is_event_next_to_backward_gap(local_event) + ) + elif direction == "b": + # We only need to check for a forward gap if we're looking backwards + # to ensure there is nothing in between + is_event_next_to_forward_gap = ( + await self.store.is_event_next_to_forward_gap(local_event) + ) + + # If we found a gap, we should probably ask another homeserver first + # about more history in between + if ( + not local_event_id + or is_event_next_to_backward_gap + or is_event_next_to_forward_gap + ): + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first", + local_event_id, + timestamp, + ) + + # Find other homeservers from the given state in the room + curr_state = await self.state_handler.get_current_state(room_id) + curr_domains = get_domains_from_state(curr_state) + likely_domains = [ + domain for domain, depth in curr_domains if domain != self.server_name + ] + + # Loop through each homeserver candidate until we get a successful response + for domain in likely_domains: + try: + remote_response = await self.federation_client.timestamp_to_event( + domain, room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: response from domain(%s)=%s", + domain, + remote_response, + ) + + # TODO: Do we want to persist this as an extremity? + # TODO: I think ideally, we would try to backfill from + # this event and run this whole + # `get_event_for_timestamp` function again to make sure + # they didn't give us an event from their gappy history. + remote_event_id = remote_response.event_id + origin_server_ts = remote_response.origin_server_ts + + # Only return the remote event if it's closer than the local event + if not local_event or ( + abs(origin_server_ts - timestamp) + < abs(local_event.origin_server_ts - timestamp) + ): + return remote_event_id, origin_server_ts + except (HttpResponseException, InvalidResponseError) as ex: + # Let's not put a high priority on some other homeserver + # failing to respond or giving a random response + logger.debug( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + except Exception as ex: + # But we do want to see some exceptions in our code + logger.warning( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + + if not local_event_id: + raise SynapseError( + 404, + "Unable to find event from %s in direction %s" % (timestamp, direction), + errcode=Codes.NOT_FOUND, + ) + + return local_event_id, local_event.origin_server_ts + + class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 91ba93372c..6dd9b9ad03 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -79,6 +79,35 @@ def parse_integer( return parse_integer_from_args(args, name, default, required) +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, +) -> Optional[int]: + ... + + +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + *, + required: Literal[True], +) -> int: + ...
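The `@overload` stubs being added here (continued just below) teach the type checker that `required=True` guarantees an `int` return. The same trick in miniature:

```python
from typing import Literal, Optional, overload

# One runtime function, two static signatures: mypy narrows the return
# type to int whenever required=True is passed.
@overload
def parse_int(value: Optional[str], *, required: Literal[True]) -> int: ...
@overload
def parse_int(value: Optional[str], required: bool = False) -> Optional[int]: ...

def parse_int(value: Optional[str], required: bool = False) -> Optional[int]:
    if value is None:
        if required:
            raise ValueError("missing required integer parameter")
        return None
    return int(value)

x: int = parse_int("42", required=True)  # type checks as int
y: Optional[int] = parse_int(None)       # type checks as Optional[int]
print(x, y)
```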
+ + +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, + required: bool = False, +) -> Optional[int]: + ... + + def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 99f303c88e..3598967be0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1070,6 +1070,62 @@ def register_txn_path( ) +class TimestampLookupRestServlet(RestServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for cases like jump to date so you can start paginating messages from + a given date in the archive. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction> + { + "event_id": ... + } + """ + + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3030" + "/rooms/(?P<room_id>[^/]*)/timestamp_to_event$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await self._auth.check_user_in_room(room_id, requester.user.to_string()) + + timestamp = parse_integer(request, "ts", required=True) + direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + + ( + event_id, + origin_server_ts, + ) = await self.timestamp_lookup_handler.get_event_for_timestamp( + requester, room_id, timestamp, direction + ) + + return 200, { + "event_id": event_id, + "origin_server_ts": origin_server_ts, + } + + class RoomSpaceSummaryRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1239,6 +1295,8 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) + if hs.config.experimental.msc3030_enabled: + TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process.
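With the servlet above registered (it only appears when the experimental flag is on; per the config hunk earlier this would be `msc3030_enabled: true`, presumably under `experimental_features` in homeserver.yaml), the endpoint can be exercised with any HTTP client. A hypothetical call with the stdlib, where the homeserver URL, token, and room ID are all placeholders:

```python
import json
import urllib.parse
import urllib.request

# Hypothetical values -- substitute a real homeserver, access token, room ID.
BASE = "https://homeserver.example.org"
TOKEN = "syt_xxx"
ROOM_ID = "!abc123:example.org"

query = urllib.parse.urlencode({"ts": 1638316800000, "dir": "f"})
url = (
    f"{BASE}/_matrix/client/unstable/org.matrix.msc3030"
    f"/rooms/{urllib.parse.quote(ROOM_ID)}/timestamp_to_event?{query}"
)
req = urllib.request.Request(url, headers={"Authorization": f"Bearer {TOKEN}"})
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)
print(body["event_id"], body["origin_server_ts"])
```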
if not is_worker: diff --git a/synapse/server.py b/synapse/server.py index 877eba6c08..185e40e4da 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -97,6 +97,7 @@ from synapse.handlers.room import ( RoomContextHandler, RoomCreationHandler, RoomShutdownHandler, + TimestampLookupHandler, ) from synapse.handlers.room_batch import RoomBatchHandler from synapse.handlers.room_list import RoomListHandler @@ -728,6 +729,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_room_context_handler(self) -> RoomContextHandler: return RoomContextHandler(self) + @cache_in_self + def get_timestamp_lookup_handler(self) -> TimestampLookupHandler: + return TimestampLookupHandler(self) + @cache_in_self def get_registration_handler(self) -> RegistrationHandler: return RegistrationHandler(self) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4cefc0a07e..fd19674f93 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1762,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore): "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, ) + + async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: + """Check if the given event is next to a backward gap of missing events. + A(False)--->B(False)--->C(True)---> + + Args: + event: The event to check + + Returns: + Boolean indicating whether the event is next to a backward gap + """ + + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: + # If the event in question has any of its prev_events listed as a + # backward extremity, it's next to a gap. + # + # We can't just check the backward edges in `event_edges` because + # when we persist events, we will also record the prev_events as + # edges to the event in question regardless of whether we have those + # prev_events yet. We need to check whether those prev_events are + # backward extremities, also known as gaps, that need to be + # backfilled. + backward_extremity_query = """ + SELECT 1 FROM event_backward_extremities + WHERE + room_id = ? + AND %s + LIMIT 1 + """ + + # If the event in question is a backward extremity or has any of its + # prev_events listed as a backward extremity, it's next to a + # backward gap. + clause, args = make_in_list_sql_clause( + self.database_engine, + "event_id", + [event.event_id] + list(event.prev_event_ids()), + ) + + txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) + backward_extremities = txn.fetchall() + + # We consider any backward extremity as a backward gap + if len(backward_extremities): + return True + + return False + + return await self.db_pool.runInteraction( + "is_event_next_to_backward_gap_txn", + is_event_next_to_backward_gap_txn, + ) + + async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: + """Check if the given event is next to a forward gap of missing events. + The gap in front of the latest events is not considered a gap. + A(False)--->B(False)--->C(False)---> + A(False)--->B(False)---> --->D(True)--->E(False) + + Args: + event: The event to check + + Returns: + Boolean indicating whether the event is next to a forward gap + """ + + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: + # If the event in question is a forward extremity, we will just + # consider any potential forward gap as not a gap since it's one of + # the latest events in the room.
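On the IN-list clause used by the backward-gap query above: `make_in_list_sql_clause` expands to an `event_id IN (?, ?, ...)` fragment sized to its argument list. The idea in miniature with `sqlite3` (toy schema, not Synapse's):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE event_backward_extremities (room_id TEXT, event_id TEXT)")
conn.execute("INSERT INTO event_backward_extremities VALUES ('!r:x', '$prev1')")

# The event itself plus its prev_events, as in the query above.
candidate_ids = ["$event", "$prev1", "$prev2"]
placeholders = ", ".join("?" for _ in candidate_ids)
clause = f"event_id IN ({placeholders})"

row = conn.execute(
    f"SELECT 1 FROM event_backward_extremities WHERE room_id = ? AND {clause} LIMIT 1",
    ["!r:x"] + candidate_ids,
).fetchone()
print("next to a backward gap" if row else "no gap")  # -> next to a backward gap
```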
+ # + # `event_forward_extremities` does not include backfilled or outlier + # events so we can't rely on it to find forward gaps. We can only + # use it to determine whether a message is the latest in the room. + # + # We can't combine this query with the `forward_edge_query` below + # because if the event in question has no forward edges (isn't + # referenced by any other event's prev_events) but is in + # `event_forward_extremities`, we don't want to return 0 rows and + # say it's next to a gap. + forward_extremity_query = """ + SELECT 1 FROM event_forward_extremities + WHERE + room_id = ? + AND event_id = ? + LIMIT 1 + """ + + # Check to see whether the event in question is already referenced + # by another event. If we don't see any edges, we're next to a + # forward gap. + forward_edge_query = """ + SELECT 1 FROM event_edges + /* Check to make sure the event referencing our event in question is not rejected */ + LEFT JOIN rejections ON event_edges.event_id = rejections.event_id + WHERE + event_edges.room_id = ? + AND event_edges.prev_event_id = ? + /* It's not a valid edge if the event referencing our event in + * question is rejected. + */ + AND rejections.event_id IS NULL + LIMIT 1 + """ + + # We consider any forward extremity as the latest in the room and + # not a forward gap. + # + # To expand, even though there is technically a gap at the front of + # the room where the forward extremities are, we consider those the + # latest messages in the room so asking other homeservers for more + # is useless. The new latest messages will just be federated as + # usual. + txn.execute(forward_extremity_query, (event.room_id, event.event_id)) + forward_extremities = txn.fetchall() + if len(forward_extremities): + return False + + # If there are no forward edges to the event in question (another + # event hasn't referenced this event in their prev_events), then we + # assume there is a forward gap in the history. + txn.execute(forward_edge_query, (event.room_id, event.event_id)) + forward_edges = txn.fetchall() + if not len(forward_edges): + return True + + return False + + return await self.db_pool.runInteraction( + "is_event_next_to_gap_txn", + is_event_next_to_gap_txn, + ) + + async def get_event_id_for_timestamp( + self, room_id: str, timestamp: int, direction: str + ) -> Optional[str]: + """Find the closest event to the given timestamp in the given direction. + + Args: + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + The closest event_id otherwise None if we can't find any event in + the given direction. + """ + + sql_template = """ + SELECT event_id FROM events + LEFT JOIN rejections USING (event_id) + WHERE + origin_server_ts %s ? + AND room_id = ? + /* Make sure event is not rejected */ + AND rejections.event_id IS NULL + ORDER BY origin_server_ts %s + LIMIT 1; + """ + + def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp.
We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" + + txn.execute( + sql_template % (comparison_operator, order), (timestamp, room_id) + ) + row = txn.fetchone() + if row: + (event_id,) = row + return event_id + + return None + + if direction not in ("f", "b"): + raise ValueError("Unknown direction: %s" % (direction,)) + + return await self.db_pool.runInteraction( + "get_event_id_for_timestamp_txn", + get_event_id_for_timestamp_txn, + ) -- cgit 1.5.1 From 858d80bf0f9f656a03992794874081b806e49222 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 2 Dec 2021 16:05:24 +0000 Subject: Fix media repository failing when media store path contains symlinks (#11446) --- changelog.d/11446.bugfix | 1 + synapse/rest/media/v1/filepath.py | 115 +++++++++++++++++++++-------------- tests/rest/media/v1/test_filepath.py | 109 ++++++++++++++++++++++++++++++++- 3 files changed, 180 insertions(+), 45 deletions(-) create mode 100644 changelog.d/11446.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/11446.bugfix b/changelog.d/11446.bugfix new file mode 100644 index 0000000000..fa5e055d50 --- /dev/null +++ b/changelog.d/11446.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.47.1 where the media repository would fail to work if the media store path contained any symbolic links. diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index c0e15c6513..1f6441c412 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -43,47 +43,75 @@ GetPathMethod = TypeVar( ) -def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod: +def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: """Wraps a path-returning method to check that the returned path(s) do not escape the media store directory. + The path-returning method may return either a single path, or a list of paths. + The check is not expected to ever fail, unless `func` is missing a call to `_validate_path_component`, or `_validate_path_component` is buggy. Args: - func: The `MediaFilePaths` method to wrap. The method may return either a single - path, or a list of paths. Returned paths may be either absolute or relative. + relative: A boolean indicating whether the wrapped method returns paths relative + to the media store directory. Returns: - The method, wrapped with a check to ensure that the returned path(s) lie within - the media store directory. Raises a `ValueError` if the check fails. + A method which will wrap a path-returning method, adding a check to ensure that + the returned path(s) lie within the media store directory. The check will raise + a `ValueError` if it fails. """ - @functools.wraps(func) - def _wrapped( - self: "MediaFilePaths", *args: Any, **kwargs: Any - ) -> Union[str, List[str]]: - path_or_paths = func(self, *args, **kwargs) - - if isinstance(path_or_paths, list): - paths_to_check = path_or_paths - else: - paths_to_check = [path_or_paths] - - for path in paths_to_check: - # path may be an absolute or relative path, depending on the method being - # wrapped. When "appending" an absolute path, `os.path.join` discards the - # previous path, which is desired here. 
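The replacement check (in the hunk that continues below) is built from `os.path.normpath` plus `os.path.commonpath`. The containment test on its own, with paths borrowed from the new tests (illustrative, not the Synapse wrapper itself):

```python
import os.path

def is_inside(base: str, path: str) -> bool:
    """Does `path` stay inside `base` after lexically collapsing ../ and ./ ?"""
    normalized = os.path.normpath(os.path.join(base, path))
    return os.path.commonpath([normalized, base]) == base

base = "/media_store"
assert is_inside(base, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
assert is_inside(base, "url_cache/2020-01-02/../../GerZNDnDZVjsOtar")          # stays inside
assert not is_inside(base, "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar")   # escapes
```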
- normalized_path = os.path.normpath(os.path.join(self.real_base_path, path)) - if ( - os.path.commonpath([normalized_path, self.real_base_path]) - != self.real_base_path - ): - raise ValueError(f"Invalid media store path: {path!r}") - - return path_or_paths - - return cast(GetPathMethod, _wrapped) + def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # Construct the path that will ultimately be used. + # We cannot guess whether `path` is relative to the media store + # directory, since the media store directory may itself be a relative + # path. + if relative: + path = os.path.join(self.base_path, path) + normalized_path = os.path.normpath(path) + + # Now that `normpath` has eliminated `../`s and `./`s from the path, + # `os.path.commonpath` can be used to check whether it lies within the + # media store directory. + if ( + os.path.commonpath([normalized_path, self.normalized_base_path]) + != self.normalized_base_path + ): + # The path resolves to outside the media store directory, + # or `self.base_path` is `.`, which is an unlikely configuration. + raise ValueError(f"Invalid media store path: {path!r}") + + # Note that `os.path.normpath`/`abspath` has a subtle caveat: + # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a + # different path if `a/b/c` is a symlink. That is, the check above is + # not perfect and may allow a certain restricted subset of untrustworthy + # paths through. Since the check above is secondary to the main + # `_validate_path_component` checks, it's less important for it to be + # perfect. + # + # As an alternative, `os.path.realpath` will resolve symlinks, but + # proves problematic if there are symlinks inside the media store. + # eg. if `url_store/` is symlinked to elsewhere, its canonical path + # won't match that of the main media store directory. + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + return _wrap_with_jail_check_inner ALLOWED_CHARACTERS = set( @@ -127,9 +155,7 @@ class MediaFilePaths: def __init__(self, primary_base_path: str): self.base_path = primary_base_path - - # The media store directory, with all symlinks resolved. - self.real_base_path = os.path.realpath(primary_base_path) + self.normalized_base_path = os.path.normpath(self.base_path) # Refuse to initialize if paths cannot be validated correctly for the current # platform. @@ -140,7 +166,7 @@ class MediaFilePaths: # for certain homeservers there, since ":"s aren't allowed in paths. 
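The normpath-versus-realpath trade-off called out in the comments above can be seen directly. A sketch (POSIX only; the symlink target here is just an assumption for illustration):

```python
import os
import os.path
import tempfile

# normpath is purely lexical; realpath consults the filesystem and follows
# symlinks. The new jail check deliberately uses the lexical form.
with tempfile.TemporaryDirectory() as d:
    os.makedirs(os.path.join(d, "a"))
    os.symlink("/tmp", os.path.join(d, "a", "b"))  # assumes a POSIX /tmp exists
    p = os.path.join(d, "a", "b")
    print(os.path.normpath(p))  # <tempdir>/a/b -- the symlink is left alone
    print(os.path.realpath(p))  # the link target, /tmp (possibly canonicalized)
```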
assert os.name == "posix" - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join( "local_content", @@ -151,7 +177,7 @@ class MediaFilePaths: local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def local_media_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -167,7 +193,7 @@ class MediaFilePaths: local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=False) def local_media_thumbnail_dir(self, media_id: str) -> str: """ Retrieve the local store path of thumbnails of a given media_id @@ -185,7 +211,7 @@ class MediaFilePaths: _validate_path_component(media_id[4:]), ) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", @@ -197,7 +223,7 @@ class MediaFilePaths: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel( self, server_name: str, @@ -223,7 +249,7 @@ class MediaFilePaths: # Legacy path that was used to store thumbnails previously. # Should be removed after some time, when most of the thumbnails are stored # using the new path. - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel_legacy( self, server_name: str, file_id: str, width: int, height: int, content_type: str ) -> str: @@ -238,6 +264,7 @@ class MediaFilePaths: _validate_path_component(file_name), ) + @_wrap_with_jail_check(relative=False) def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, @@ -248,7 +275,7 @@ class MediaFilePaths: _validate_path_component(file_id[4:]), ) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form @@ -268,7 +295,7 @@ class MediaFilePaths: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=False) def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): @@ -290,7 +317,7 @@ class MediaFilePaths: ), ] - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -318,7 +345,7 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -341,7 +368,7 @@ class MediaFilePaths: url_cache_thumbnail_directory_rel ) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=False) def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 8fe94f7d85..913bc530aa 100644 --- a/tests/rest/media/v1/test_filepath.py +++ 
b/tests/rest/media/v1/test_filepath.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import os from typing import Iterable -from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check from tests import unittest @@ -486,3 +487,109 @@ class MediaFilePathsTestCase(unittest.TestCase): f"{value!r} unexpectedly passed validation: " f"{method} returned {path_or_list!r}" ) + + +class MediaFilePathsJailTestCase(unittest.TestCase): + def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None: + """Passes a relative path through the jail check. + + Args: + filepaths: The `MediaFilePaths` instance. + path: A path relative to the media store directory. + + Raises: + ValueError: If the jail check fails. + """ + + @_wrap_with_jail_check(relative=True) + def _make_relative_path(self: MediaFilePaths, path: str) -> str: + return path + + _make_relative_path(filepaths, path) + + def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None: + """Passes an absolute path through the jail check. + + Args: + filepaths: The `MediaFilePaths` instance. + path: A path relative to the media store directory. + + Raises: + ValueError: If the jail check fails. + """ + + @_wrap_with_jail_check(relative=False) + def _make_absolute_path(self: MediaFilePaths, path: str) -> str: + return os.path.join(self.base_path, path) + + _make_absolute_path(filepaths, path) + + def test_traversal_inside(self) -> None: + """Test the jail check for paths that stay within the media directory.""" + # Despite the `../`s, these paths still lie within the media directory and it's + # expected for the jail check to allow them through. + # These paths ought to trip the other checks in place and should never be + # returned. + filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar" + self._check_relative_path(filepaths, path) + self._check_absolute_path(filepaths, path) + + def test_traversal_outside(self) -> None: + """Test that the jail check fails for paths that escape the media directory.""" + filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar" + with self.assertRaises(ValueError): + self._check_relative_path(filepaths, path) + with self.assertRaises(ValueError): + self._check_absolute_path(filepaths, path) + + def test_traversal_reentry(self) -> None: + """Test the jail check for paths that exit and re-enter the media directory.""" + # These paths lie outside the media directory if it is a symlink, and inside + # otherwise. Ideally the check should fail, but this proves difficult. + # This test documents the behaviour for this edge case. + # These paths ought to trip the other checks in place and should never be + # returned. 
+ filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar" + self._check_relative_path(filepaths, path) + self._check_absolute_path(filepaths, path) + + def test_symlink(self) -> None: + """Test that a symlink does not cause the jail check to fail.""" + media_store_path = self.mktemp() + + # symlink the media store directory + os.symlink("/mnt/synapse/media_store", media_store_path) + + # Test that relative and absolute paths don't trip the check + # NB: `media_store_path` is a relative path + filepaths = MediaFilePaths(media_store_path) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + filepaths = MediaFilePaths(os.path.abspath(media_store_path)) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + def test_symlink_subdirectory(self) -> None: + """Test that a symlinked subdirectory does not cause the jail check to fail.""" + media_store_path = self.mktemp() + os.mkdir(media_store_path) + + # symlink `url_cache/` + os.symlink( + "/mnt/synapse/media_store_url_cache", + os.path.join(media_store_path, "url_cache"), + ) + + # Test that relative and absolute paths don't trip the check + # NB: `media_store_path` is a relative path + filepaths = MediaFilePaths(media_store_path) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + filepaths = MediaFilePaths(os.path.abspath(media_store_path)) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") -- cgit 1.5.1 From 494ebd7347ba52d702802fba4c3bb13e7bfbc2cf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Dec 2021 10:51:15 -0500 Subject: Include bundled aggregations in /sync and related fixes (#11478) Due to updates to MSC2675 this includes a few fixes: * Include bundled aggregations for /sync. * Do not include bundled aggregations for /initialSync and /events. * Do not bundle aggregations for state events. * Clarifies comments and variable names. --- changelog.d/11478.bugfix | 1 + synapse/events/utils.py | 58 ++++++++++------ synapse/handlers/events.py | 5 +- synapse/handlers/initial_sync.py | 30 ++++++-- synapse/handlers/message.py | 8 +-- synapse/rest/admin/rooms.py | 13 +--- synapse/rest/client/relations.py | 9 ++- synapse/rest/client/room.py | 5 +- synapse/rest/client/sync.py | 6 +- tests/rest/client/test_relations.py | 135 +++++++++++++++++++++++++----------- 10 files changed, 169 insertions(+), 101 deletions(-) create mode 100644 changelog.d/11478.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/11478.bugfix b/changelog.d/11478.bugfix new file mode 100644 index 0000000000..5f02636f50 --- /dev/null +++ b/changelog.d/11478.bugfix @@ -0,0 +1 @@ +Include bundled relation aggregations during a limited `/sync` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). 
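The first `synapse/events/utils.py` hunk below also makes the serializer's options keyword-only via a bare `*`. The effect of that marker in miniature:

```python
# The bare * makes every following parameter keyword-only, so call sites
# must spell out flag names instead of passing bare booleans positionally.
def serialize(event: dict, time_now: int, *, as_client_event: bool = True) -> dict:
    return {"event": event, "ts": time_now, "client": as_client_event}

serialize({}, 0, as_client_event=False)  # OK
try:
    serialize({}, 0, False)  # rejected: flag must be passed by keyword
except TypeError as e:
    print(e)
```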
diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 05219a9dd0..84ef69df67 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -306,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, + *, as_client_event: bool = True, event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, token_id: Optional[str] = None, @@ -393,7 +394,8 @@ class EventClientSerializer: self, event: Union[JsonDict, EventBase], time_now: int, - bundle_relations: bool = True, + *, + bundle_aggregations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -401,8 +403,9 @@ class EventClientSerializer: Args: event: The event being serialized. time_now: The current time in milliseconds - bundle_relations: Whether to include the bundled relations for this - event. + bundle_aggregations: Whether to include the bundled aggregations for this + event. Only applies to non-state events. (State events never include + bundled aggregations.) **kwargs: Arguments to pass to `serialize_event` Returns: @@ -414,20 +417,27 @@ serialized_event = serialize_event(event, time_now, **kwargs) - # If MSC1849 is enabled then we need to look if there are any relations - # we need to bundle in with the event. - # Do not bundle relations if the event has been redacted - if not event.internal_metadata.is_redacted() and ( - self._msc1849_enabled and bundle_relations + # Check if there are any bundled aggregations to include with the event. + # + # Do not bundle aggregations if any of the following are true: + # + # * Support is disabled via the configuration or the caller. + # * The event is a state event. + # * The event has been redacted. + if ( + self._msc1849_enabled + and bundle_aggregations + and not event.is_state() + and not event.internal_metadata.is_redacted() ): - await self._injected_bundled_relations(event, time_now, serialized_event) + await self._injected_bundled_aggregations(event, time_now, serialized_event) return serialized_event - async def _injected_bundled_relations( + async def _injected_bundled_aggregations( self, event: EventBase, time_now: int, serialized_event: JsonDict ) -> None: - """Potentially injects bundled relations into the unsigned portion of the serialized event. + """Potentially injects bundled aggregations into the unsigned portion of the serialized event. Args: event: The event being serialized. @@ -435,7 +445,7 @@ serialized_event: The serialized event which may be modified. """ - # Do not bundle relations for an event which represents an edit or an + # Do not bundle aggregations for an event which represents an edit or an # annotation. It does not make sense for them to have related events. relates_to = event.content.get("m.relates_to") if isinstance(relates_to, (dict, frozendict)): @@ -445,18 +455,18 @@ event_id = event.event_id - # The bundled relations to include.
+ aggregations = {} annotations = await self.store.get_aggregation_groups_for_event(event_id) if annotations.chunk: - relations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - relations[RelationTypes.REFERENCE] = references.to_dict() + aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: @@ -482,7 +492,7 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - relations[RelationTypes.REPLACE] = { + aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, @@ -495,17 +505,19 @@ class EventClientSerializer: latest_thread_event, ) = await self.store.get_thread_summary(event_id) if latest_thread_event: - relations[RelationTypes.THREAD] = { - # Don't bundle relations as this could recurse forever. + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_relations=False + latest_thread_event, time_now, bundle_aggregations=False ), "count": thread_count, } - # If any bundled relations were found, include them. - if relations: - serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) + # If any bundled aggregations were found, include them. + if aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + aggregations + ) async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index b4ff935546..32b0254c5f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -122,9 +122,8 @@ class EventStreamHandler: events, time_now, as_client_event=as_client_event, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, ) chunk = { diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d4e4556155..9cd21e7f2b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -165,7 +165,11 @@ class InitialSyncHandler: invite_event = await self.store.get_event(event.event_id) d["invite"] = await self._event_serializer.serialize_event( - invite_event, time_now, as_client_event + invite_event, + time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) rooms_ret.append(d) @@ -216,7 +220,11 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( await self._event_serializer.serialize_events( - messages, time_now=time_now, as_client_event=as_client_event + messages, + time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) ), "start": await start_token.to_string(self.store), @@ -226,6 +234,8 @@ class InitialSyncHandler: d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. 
+ bundle_aggregations=False, as_client_event=as_client_event, ) @@ -366,14 +376,18 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( + # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now + room_state.values(), time_now, bundle_aggregations=False ) ), "presence": [], @@ -392,8 +406,9 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), time_now, bundle_aggregations=False ) now_token = self.hs.get_event_sources().get_current_token() @@ -467,7 +482,10 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 95b4fad3c6..87f671708c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -247,13 +247,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events( - room_state.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_relations=False, - ) + events = await self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6bbc5510f0..669ab44a45 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -449,13 +449,7 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events( - events.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. 
- bundle_relations=False, - ) + room_state = await self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -789,10 +783,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return HTTPStatus.OK, results diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b1a3304849..fc4e6921c5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,14 +224,13 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_relations to False when retrieving the original - # event because we want the content before relations were applied to - # it. + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_relations=False + event, now, bundle_aggregations=False ) # The relations returned for the requested event do include their - # bundled relations. + # bundled aggregations. serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 3598967be0..f48e2e6ca2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index b6a2485732..88e4f5e063 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet): return self._event_serializer.serialize_events( events, time_now=time_now, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. 
+ bundle_aggregations=room.timeline.limited, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index b494da5138..397c12c2a6 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client import login, register, relations, room +from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel @@ -29,6 +29,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, + sync.register_servlets, login.register_servlets, register.register_servlets, admin.register_servlets_for_client_rest_resource, @@ -454,11 +455,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_aggregation_get_event(self): - """Test that annotations, references, and threads get correctly bundled when - getting the parent event. - """ - + def test_bundled_aggregations(self): + """Test that annotations, references, and threads get correctly bundled.""" + # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -485,49 +484,107 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) + def assert_bundle(actual): + """Assert the expected values of the bundled aggregations.""" - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { + # Ensure the fields are as expected. + self.assertCountEqual( + actual.keys(), + ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.THREAD, + ), + ) + + # Check the values of each field. + self.assertEquals( + { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, {"type": "m.reaction", "key": "b", "count": 1}, ] }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - RelationTypes.THREAD: { - "count": 2, - "latest_event": { - "age": 100, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "origin_server_ts": 1600, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "unsigned": {"age": 100}, - "user_id": self.user_id, + actual[RelationTypes.ANNOTATION], + ) + + self.assertEquals( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + actual[RelationTypes.REFERENCE], + ) + + self.assertEquals( + 2, + actual[RelationTypes.THREAD].get("count"), + ) + # The latest thread event has some fields that don't matter. 
+ self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } }, + "event_id": thread_2, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "user_id": self.user_id, }, - }, + actual[RelationTypes.THREAD].get("latest_event"), + ) + + def _find_and_assert_event(events): + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + break + else: + raise AssertionError(f"Event {self.parent_id} not found in chunk") + assert_bundle(event["unsigned"].get("m.relations")) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["unsigned"].get("m.relations")) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, ) + self.assertEquals(200, channel.code, channel.json_body) + _find_and_assert_event(channel.json_body["chunk"]) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + + # Request sync. + channel = self.make_request("GET", "/sync", access_token=self.user_token) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + _find_and_assert_event(room_timeline["events"]) + + # Note that /relations is tested separately in test_aggregation_get_event_for_thread + # since it needs different data configured. def test_aggregation_get_event_for_annotation(self): - """Test that annotations do not get bundled relations included + """Test that annotations do not get bundled aggregations included when directly requested. 
""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -549,7 +606,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) def test_aggregation_get_event_for_thread(self): - """Test that threads get bundled relations included when directly requested.""" + """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] -- cgit 1.5.1 From 8b4b153c9e86c04c7db8c74fde4b6a04becbc461 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 6 Dec 2021 17:59:50 +0100 Subject: Add admin API to get some information about federation status (#11407) --- changelog.d/11407.feature | 1 + docs/SUMMARY.md | 1 + docs/usage/administration/admin_api/federation.md | 114 ++++++ synapse/rest/admin/__init__.py | 6 + synapse/rest/admin/federation.py | 135 +++++++ synapse/storage/databases/main/transactions.py | 70 ++++ tests/rest/admin/test_federation.py | 456 ++++++++++++++++++++++ 7 files changed, 783 insertions(+) create mode 100644 changelog.d/11407.feature create mode 100644 docs/usage/administration/admin_api/federation.md create mode 100644 synapse/rest/admin/federation.py create mode 100644 tests/rest/admin/test_federation.py (limited to 'synapse/rest') diff --git a/changelog.d/11407.feature b/changelog.d/11407.feature new file mode 100644 index 0000000000..1d21bde98f --- /dev/null +++ b/changelog.d/11407.feature @@ -0,0 +1 @@ +Add admin API to get some information about federation status with remote servers. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 41c8f0fbc9..b05af6d690 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -65,6 +65,7 @@ - [Statistics](admin_api/statistics.md) - [Users](admin_api/user_admin_api.md) - [Server Version](admin_api/version_api.md) + - [Federation](usage/administration/admin_api/federation.md) - [Manhole](manhole.md) - [Monitoring](metrics-howto.md) - [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md) diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md new file mode 100644 index 0000000000..8f9535f57b --- /dev/null +++ b/docs/usage/administration/admin_api/federation.md @@ -0,0 +1,114 @@ +# Federation API + +This API allows a server administrator to manage Synapse's federation with other homeservers. + +Note: This API is new, experimental and "subject to change". + +## List of destinations + +This API gets the current destination retry timing info for all remote servers. + +The list contains all the servers with which the server federates, +regardless of whether an error occurred or not. +If an error occurs, it may take up to 20 minutes for the error to be displayed here, +as a complete retry must have failed. 
+
+The API is:
+
+A standard request with no filtering:
+
+```
+GET /_synapse/admin/v1/federation/destinations
+```
+
+A response body like the following is returned:
+
+```json
+{
+   "destinations":[
+     {
+       "destination": "matrix.org",
+       "retry_last_ts": 1557332397936,
+       "retry_interval": 3000000,
+       "failure_ts": 1557329397936,
+       "last_successful_stream_ordering": null
+     }
+   ],
+   "total": 1
+}
+```
+
+To paginate, check for `next_token` and if present, call the endpoint again
+with `from` set to the value of `next_token`. This will return a new page.
+
+If the endpoint does not return a `next_token` then there are no more destinations
+to paginate through.
+
+**Parameters**
+
+The following query parameters are available:
+
+- `from` - Offset in the returned list. Defaults to `0`.
+- `limit` - Maximum number of destinations to return. Defaults to `100`.
+- `order_by` - The method by which to sort the returned list of destinations.
+  Valid values are:
+  - `destination` - Destinations are ordered alphabetically by remote server name.
+    This is the default.
+  - `retry_last_ts` - Destinations are ordered by the time of the last retry attempt, in ms.
+  - `retry_interval` - Destinations are ordered by how long until the next retry, in ms.
+  - `failure_ts` - Destinations are ordered by when the server started failing, in ms.
+  - `last_successful_stream_ordering` - Destinations are ordered by the stream ordering
+    of the most recent successfully-sent PDU.
+- `dir` - Direction of the sort order. Either `f` for forwards or `b` for backwards. Setting
+  this value to `b` will reverse the above sort order. Defaults to `f`.
+
+*Caution:* The database only has an index on the column `destination`.
+Using a different sort order can therefore cause a large load on the database,
+especially in large environments.
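+
+For example, to request the second page of up to 5 destinations, ordered by how
+long until the next retry (the parameter values here are purely illustrative;
+any of the parameters above can be combined in the same way):
+
+```
+GET /_synapse/admin/v1/federation/destinations?from=5&limit=5&order_by=retry_interval
+```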
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- `destinations` - An array of objects, each containing information about a destination.
+  Destination objects contain the following fields:
+  - `destination` - string - Name of the remote server to federate with.
+  - `retry_last_ts` - integer - The last time Synapse tried and failed to reach the
+    remote server, in ms. This is `0` if the last attempt to communicate with the
+    remote server was successful.
+  - `retry_interval` - integer - How long Synapse waits after a failed attempt before
+    trying to reach the remote server again, in ms. This is `0` if no further retrying
+    is occurring.
+  - `failure_ts` - nullable integer - The first time Synapse tried and failed to reach the
+    remote server, in ms. This is `null` if communication with the remote server has never failed.
+  - `last_successful_stream_ordering` - nullable integer - The stream ordering of the most
+    recent successfully-sent [PDU](understanding_synapse_through_grafana_graphs.md#federation)
+    to this destination, or `null` if this information has not been tracked yet.
+- `next_token` - string representing a positive integer - Indication for pagination. See above.
+- `total` - integer - Total number of destinations.
+
+# Destination Details API
+
+This API gets the retry timing info for a specific remote server.
+
+The API is:
+
+```
+GET /_synapse/admin/v1/federation/destinations/<destination>
+```
+
+A response body like the following is returned:
+
+```json
+{
+   "destination": "matrix.org",
+   "retry_last_ts": 1557332397936,
+   "retry_interval": 3000000,
+   "failure_ts": 1557329397936,
+   "last_successful_stream_ordering": null
+}
+```
+
+**Response**
+
+The response fields are the same as in the `destinations` array of the
+[List of destinations](#list-of-destinations) response.
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c51a029bf3..c499afd4be 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -40,6 +40,10 @@ from synapse.rest.admin.event_reports import (
     EventReportDetailRestServlet,
     EventReportsRestServlet,
 )
+from synapse.rest.admin.federation import (
+    DestinationsRestServlet,
+    ListDestinationsRestServlet,
+)
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
 from synapse.rest.admin.registration_tokens import (
@@ -261,6 +265,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ListRegistrationTokensRestServlet(hs).register(http_server)
     NewRegistrationTokenRestServlet(hs).register(http_server)
     RegistrationTokenRestServlet(hs).register(http_server)
+    DestinationsRestServlet(hs).register(http_server)
+    ListDestinationsRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if hs.config.worker.worker_app is None:
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
new file mode 100644
index 0000000000..744687be35
--- /dev/null
+++ b/synapse/rest/admin/federation.py
@@ -0,0 +1,135 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.storage.databases.main.transactions import DestinationSortOrder
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListDestinationsRestServlet(RestServlet):
+    """Get request to list all destinations.
+    This needs the user to have administrator access in Synapse.
+
+    GET /_synapse/admin/v1/federation/destinations?from=0&limit=10
+
+    returns:
+        200 OK with a list of destinations on success, otherwise an error.
+
+    The parameters `from` and `limit` are only used for pagination.
+    By default, a `limit` of 100 is used.
+    The parameter `destination` can be used to filter by destination.
+    The parameter `order_by` can be used to order the result.
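+
+    Example (an illustrative query combining the parameters above):
+        GET /_synapse/admin/v1/federation/destinations?from=0&limit=10&order_by=failure_ts&dir=b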
+    """
+
+    PATTERNS = admin_patterns("/federation/destinations$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        destination = parse_string(request, "destination")
+
+        order_by = parse_string(
+            request,
+            "order_by",
+            default=DestinationSortOrder.DESTINATION.value,
+            allowed_values=[dest.value for dest in DestinationSortOrder],
+        )
+
+        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+
+        destinations, total = await self._store.get_destinations_paginate(
+            start, limit, destination, order_by, direction
+        )
+        response = {"destinations": destinations, "total": total}
+        if (start + limit) < total:
+            response["next_token"] = str(start + len(destinations))
+
+        return HTTPStatus.OK, response
+
+
+class DestinationsRestServlet(RestServlet):
+    """Get details of a destination.
+    This needs the user to have administrator access in Synapse.
+
+    GET /_synapse/admin/v1/federation/destinations/<destination>
+
+    returns:
+        200 OK with details of a destination on success, otherwise an error.
+    """
+
+    PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_GET(
+        self, request: SynapseRequest, destination: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        destination_retry_timings = await self._store.get_destination_retry_timings(
+            destination
+        )
+
+        if not destination_retry_timings:
+            raise NotFoundError("Unknown destination")
+
+        last_successful_stream_ordering = (
+            await self._store.get_destination_last_successful_stream_ordering(
+                destination
+            )
+        )
+
+        response = {
+            "destination": destination,
+            "failure_ts": destination_retry_timings.failure_ts,
+            "retry_last_ts": destination_retry_timings.retry_last_ts,
+            "retry_interval": destination_retry_timings.retry_interval,
+            "last_successful_stream_ordering": last_successful_stream_ordering,
+        }
+
+        return HTTPStatus.OK, response
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d7dc1f73ac..1622822552 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,6 +14,7 @@
 
 import logging
 from collections import namedtuple
+from enum import Enum
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 import attr
@@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
 )
 
 
+class DestinationSortOrder(Enum):
+    """Enum to define the sorting method used when returning destinations."""
+
+    DESTINATION = "destination"
+    RETRY_LAST_TS = "retry_last_ts"
+    RETRY_INTERVAL = "retry_interval"
+    FAILURE_TS = "failure_ts"
+    LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class DestinationRetryTimings:
     """The current destination retry
timing info for a remote server.""" @@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): destinations = [row[0] for row in txn] return destinations + + async def get_destinations_paginate( + self, + start: int, + limit: int, + destination: Optional[str] = None, + order_by: str = DestinationSortOrder.DESTINATION.value, + direction: str = "f", + ) -> Tuple[List[JsonDict], int]: + """Function to retrieve a paginated list of destinations. + This will return a json list of destinations and the + total number of destinations matching the filter criteria. + + Args: + start: start number to begin the query from + limit: number of rows to retrieve + destination: search string in destination + order_by: the sort order of the returned list + direction: sort ascending or descending + Returns: + A tuple of a list of mappings from destination to information + and a count of total destinations. + """ + + def get_destinations_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + order_by_column = DestinationSortOrder(order_by).value + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + args = [] + where_statement = "" + if destination: + args.extend(["%" + destination.lower() + "%"]) + where_statement = "WHERE LOWER(destination) LIKE ?" + + sql_base = f"FROM destinations {where_statement} " + sql = f"SELECT COUNT(*) as total_destinations {sql_base}" + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = f""" + SELECT destination, retry_last_ts, retry_interval, failure_ts, + last_successful_stream_ordering + {sql_base} + ORDER BY {order_by_column} {order}, destination ASC + LIMIT ? OFFSET ? + """ + txn.execute(sql, args + [limit, start]) + destinations = self.db_pool.cursor_to_dict(txn) + return destinations, count + + return await self.db_pool.runInteraction( + "get_destinations_paginate_txn", get_destinations_paginate_txn + ) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py new file mode 100644 index 0000000000..5188499ef2 --- /dev/null +++ b/tests/rest/admin/test_federation.py @@ -0,0 +1,456 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from http import HTTPStatus
+from typing import List, Optional
+
+from parameterized import parameterized
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+
+from tests import unittest
+
+
+class FederationTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs: HomeServer):
+        self.store = hs.get_datastore()
+        self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.url = "/_synapse/admin/v1/federation/destinations"
+
+    @parameterized.expand(
+        [
+            ("/_synapse/admin/v1/federation/destinations",),
+            ("/_synapse/admin/v1/federation/destinations/dummy",),
+        ]
+    )
+    def test_requester_is_no_admin(self, url: str):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        self.register_user("user", "pass", admin=False)
+        other_user_tok = self.login("user", "pass")
+
+        channel = self.make_request(
+            "GET",
+            url,
+            content={},
+            access_token=other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+
+        # negative limit
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # unknown order_by
+        channel = self.make_request(
+            "GET",
+            self.url + "?order_by=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid search order
+        channel = self.make_request(
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid destination
+        channel = self.make_request(
+            "GET",
+            self.url + "/dummy",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_limit(self):
+        """
+        Testing list of destinations with limit
+        """
+
+        number_destinations = 20
+        self._create_destinations(number_destinations)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_destinations)
+        self.assertEqual(len(channel.json_body["destinations"]), 5)
+        self.assertEqual(channel.json_body["next_token"], "5")
+        self._check_fields(channel.json_body["destinations"])
+
+    def test_from(self):
+        """
+        Testing list of destinations with a defined starting point (from)
+        """
+
+        number_destinations = 20
+        
self._create_destinations(number_destinations)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_destinations)
+        self.assertEqual(len(channel.json_body["destinations"]), 15)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["destinations"])
+
+    def test_limit_and_from(self):
+        """
+        Testing list of destinations with a defined starting point and limit
+        """
+
+        number_destinations = 20
+        self._create_destinations(number_destinations)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=5&limit=10",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_destinations)
+        self.assertEqual(channel.json_body["next_token"], "15")
+        self.assertEqual(len(channel.json_body["destinations"]), 10)
+        self._check_fields(channel.json_body["destinations"])
+
+    def test_next_token(self):
+        """
+        Testing that `next_token` appears at the right place
+        """
+
+        number_destinations = 20
+        self._create_destinations(number_destinations)
+
+        # `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=20",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_destinations)
+        self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
+        self.assertNotIn("next_token", channel.json_body)
+
+        # `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=21",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_destinations)
+        self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
+        self.assertNotIn("next_token", channel.json_body)
+
+        # `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=19",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_destinations)
+        self.assertEqual(len(channel.json_body["destinations"]), 19)
+        self.assertEqual(channel.json_body["next_token"], "19")
+
+        # Set `from` to the value of `next_token` to request the remaining
+        # entries; `next_token` should no longer appear.
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=19",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], number_destinations)
+        self.assertEqual(len(channel.json_body["destinations"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def test_list_all_destinations(self):
+        """
+        List all destinations.
+        """
+        number_destinations = 5
+        self._create_destinations(number_destinations)
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(number_destinations, len(channel.json_body["destinations"]))
+        self.assertEqual(number_destinations, channel.json_body["total"])
+
+        # Check that all fields are available
+        self._check_fields(channel.json_body["destinations"])
+
+    def test_order_by(self):
+        """
+        Testing the ordering of the returned list via the parameter `order_by`
+        """
+
+        def _order_test(
+            expected_destination_list: List[str],
+            order_by: Optional[str],
+            dir: Optional[str] = None,
+        ):
+            """Request the list of destinations in a certain order.
+            Assert that the order is what we expect.
+
+            Args:
+                expected_destination_list: The list of destinations in the order
+                    we expect to get back from the server
+                order_by: The type of ordering to give the server
+                dir: The direction of ordering to give the server
+            """
+
+            url = f"{self.url}?"
+            if order_by is not None:
+                url += f"order_by={order_by}&"
+            if dir is not None and dir in ("b", "f"):
+                url += f"dir={dir}"
+            channel = self.make_request(
+                "GET",
+                url,
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertEqual(channel.json_body["total"], len(expected_destination_list))
+
+            returned_order = [
+                row["destination"] for row in channel.json_body["destinations"]
+            ]
+            self.assertEqual(expected_destination_list, returned_order)
+            self._check_fields(channel.json_body["destinations"])
+
+        # create destinations
+        dest = [
+            ("sub-a.example.com", 100, 300, 200, 300),
+            ("sub-b.example.com", 200, 200, 100, 100),
+            ("sub-c.example.com", 300, 100, 300, 200),
+        ]
+        for (
+            destination,
+            failure_ts,
+            retry_last_ts,
+            retry_interval,
+            last_successful_stream_ordering,
+        ) in dest:
+            self.get_success(
+                self.store.set_destination_retry_timings(
+                    destination, failure_ts, retry_last_ts, retry_interval
+                )
+            )
+            self.get_success(
+                self.store.set_destination_last_successful_stream_ordering(
+                    destination, last_successful_stream_ordering
+                )
+            )
+
+        # order by default (destination)
+        _order_test([dest[0][0], dest[1][0], dest[2][0]], None)
+        _order_test([dest[0][0], dest[1][0], dest[2][0]], None, "f")
+        _order_test([dest[2][0], dest[1][0], dest[0][0]], None, "b")
+
+        # order by destination
+        _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination")
+        _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination", "f")
+        _order_test([dest[2][0], dest[1][0], dest[0][0]], "destination", "b")
+
+        # order by failure_ts
+        _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts")
+        _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts", "f")
+        _order_test([dest[2][0], dest[1][0], dest[0][0]], "failure_ts", "b")
+
+        # order by retry_last_ts
+        _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts")
+        _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts", "f")
+        _order_test([dest[0][0], dest[1][0], dest[2][0]], "retry_last_ts", "b")
+
+        # order by retry_interval
+        _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval")
+        _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval", "f")
+        _order_test([dest[2][0], dest[0][0], dest[1][0]], "retry_interval", "b")
+
+        # order by last_successful_stream_ordering
+        _order_test(
+            [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering"
+        )
+        _order_test(
+            [dest[1][0], 
dest[2][0], dest[0][0]], "last_successful_stream_ordering", "f"
+        )
+        _order_test(
+            [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b"
+        )
+
+    def test_search_term(self):
+        """Test that searching for a destination works correctly"""
+
+        def _search_test(
+            expected_destination: Optional[str],
+            search_term: str,
+        ):
+            """Search for a destination and check that the returned destination is a match
+
+            Args:
+                expected_destination: The destination expected to be returned by the API.
+                    Set to None to expect zero results for the search
+                search_term: The term to search server names with
+            """
+            url = f"{self.url}?destination={search_term}"
+            channel = self.make_request(
+                "GET",
+                url.encode("ascii"),
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+            # Check that destinations were returned
+            self.assertTrue("destinations" in channel.json_body)
+            self._check_fields(channel.json_body["destinations"])
+            destinations = channel.json_body["destinations"]
+
+            # Check that the expected number of destinations were returned
+            expected_destination_count = 1 if expected_destination else 0
+            self.assertEqual(len(destinations), expected_destination_count)
+            self.assertEqual(channel.json_body["total"], expected_destination_count)
+
+            if expected_destination:
+                # Check that the first returned destination is correct
+                self.assertEqual(expected_destination, destinations[0]["destination"])
+
+        number_destinations = 3
+        self._create_destinations(number_destinations)
+
+        # Test searching
+        _search_test("sub0.example.com", "0")
+        _search_test("sub0.example.com", "sub0")
+
+        _search_test("sub1.example.com", "1")
+        _search_test("sub1.example.com", "1.")
+
+        # Test case insensitive
+        _search_test("sub0.example.com", "SUB0")
+
+        _search_test(None, "foo")
+        _search_test(None, "bar")
+
+    def test_get_single_destination(self):
+        """
+        Get one specific destination.
+        """
+        self._create_destinations(5)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "/sub0.example.com",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual("sub0.example.com", channel.json_body["destination"])
+
+        # Check that all fields are available
+        # convert channel.json_body into a List
+        self._check_fields([channel.json_body])
+
+    def _create_destinations(self, number_destinations: int):
+        """Create a number of destinations
+
+        Args:
+            number_destinations: Number of destinations to be created
+        """
+        for i in range(0, number_destinations):
+            dest = f"sub{i}.example.com"
+            self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50))
+            self.get_success(
+                self.store.set_destination_last_successful_stream_ordering(dest, 100)
+            )
+
+    def _check_fields(self, content: List[JsonDict]):
+        """Checks that the expected destination attributes are present in content
+
+        Args:
+            content: List that is checked for content
+        """
+        for c in content:
+            self.assertIn("destination", c)
+            self.assertIn("retry_last_ts", c)
+            self.assertIn("retry_interval", c)
+            self.assertIn("failure_ts", c)
+            self.assertIn("last_successful_stream_ordering", c)
-- 
cgit 1.5.1


From a15a893df8428395df7cb95b729431575001c38a Mon Sep 17 00:00:00 2001
From: Quentin Gliech
Date: Mon, 6 Dec 2021 18:43:06 +0100
Subject: Save the OIDC session ID (sid) with the device on login (#11482)

As a step towards allowing back-channel logout for OIDC.
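
A sketch of what a back-channel logout implementation could later do with the
stored session ID (hypothetical glue code: only
`get_devices_by_auth_provider_session_id` and `delete_device` are real methods
in this codebase; the surrounding handler and its `_store` and
`_device_handler` attributes are assumptions):

    async def _logout_oidc_session(self, auth_provider_id: str, sid: str) -> None:
        # Find every device that was registered through this IdP session,
        # using the lookup method added in this commit.
        devices = await self._store.get_devices_by_auth_provider_session_id(
            auth_provider_id, sid
        )
        for device in devices:
            # Deleting a device also invalidates its access tokens.
            await self._device_handler.delete_device(
                device["user_id"], device["device_id"]
            )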
--- changelog.d/11482.misc | 1 + synapse/handlers/auth.py | 34 +++++- synapse/handlers/device.py | 8 ++ synapse/handlers/oidc.py | 58 +++++---- synapse/handlers/register.py | 15 ++- synapse/handlers/sso.py | 4 + synapse/module_api/__init__.py | 2 + synapse/replication/http/login.py | 8 ++ synapse/rest/client/login.py | 7 +- synapse/storage/databases/main/devices.py | 50 +++++++- .../delta/65/11_devices_auth_provider_session.sql | 27 +++++ tests/handlers/test_auth.py | 6 +- tests/handlers/test_cas.py | 40 +++++- tests/handlers/test_oidc.py | 135 ++++++++++++++++++--- tests/handlers/test_saml.py | 40 +++++- 15 files changed, 370 insertions(+), 65 deletions(-) create mode 100644 changelog.d/11482.misc create mode 100644 synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql (limited to 'synapse/rest') diff --git a/changelog.d/11482.misc b/changelog.d/11482.misc new file mode 100644 index 0000000000..e78662988f --- /dev/null +++ b/changelog.d/11482.misc @@ -0,0 +1 @@ +Save the OpenID Connect session ID on login. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4d9c4e5834..61607cf2ba 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -39,6 +39,7 @@ import attr import bcrypt import pymacaroons import unpaddedbase64 +from pymacaroons.exceptions import MacaroonVerificationFailedException from twisted.web.server import Request @@ -182,8 +183,11 @@ class LoginTokenAttributes: user_id = attr.ib(type=str) - # the SSO Identity Provider that the user authenticated with, to get this token auth_provider_id = attr.ib(type=str) + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id = attr.ib(type=Optional[str]) + """The session ID advertised by the SSO Identity Provider.""" class AuthHandler: @@ -1650,6 +1654,7 @@ class AuthHandler: client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> None: """Having figured out a mxid for this user, complete the HTTP request @@ -1665,6 +1670,7 @@ class AuthHandler: during successful login. Must be JSON serializable. new_user: True if we should use wording appropriate to a user who has just registered. + auth_provider_session_id: The session ID from the SSO IdP received during login. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1685,6 +1691,7 @@ class AuthHandler: extra_attributes, new_user=new_user, user_profile_data=profile, + auth_provider_session_id=auth_provider_session_id, ) def _complete_sso_login( @@ -1696,6 +1703,7 @@ class AuthHandler: extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ The synchronous portion of complete_sso_login. @@ -1717,7 +1725,9 @@ class AuthHandler: # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id, auth_provider_id=auth_provider_id + registered_user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) # Append the login token to the original redirect URL (i.e. 
with its query @@ -1822,6 +1832,7 @@ class MacaroonGenerator: self, user_id: str, auth_provider_id: str, + auth_provider_session_id: Optional[str] = None, duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1830,6 +1841,10 @@ class MacaroonGenerator: expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) + if auth_provider_session_id is not None: + macaroon.add_first_party_caveat( + "auth_provider_session_id = %s" % (auth_provider_session_id,) + ) return macaroon.serialize() def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: @@ -1851,15 +1866,28 @@ class MacaroonGenerator: user_id = get_value_from_macaroon(macaroon, "user_id") auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") + auth_provider_session_id: Optional[str] = None + try: + auth_provider_session_id = get_value_from_macaroon( + macaroon, "auth_provider_session_id" + ) + except MacaroonVerificationFailedException: + pass + v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") v.satisfy_exact("type = login") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) + v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) satisfy_expiry(v, self.hs.get_clock().time_msec) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) + return LoginTokenAttributes( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 68b446eb66..82ee11e921 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: str, device_id: Optional[str], initial_device_display_name: Optional[str] = None, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> str: """ If the given device has not been registered, register it with the @@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: @user:id device_id: device id supplied by client initial_device_display_name: device display name from client + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from the SSO IdP. 
Returns: device id (generated if none was supplied) """ @@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [device_id]) @@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=new_device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [new_device_id]) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 3665d91513..deb3539751 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -23,7 +23,7 @@ from authlib.common.security import generate_token from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri -from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template from pymacaroons.exceptions import ( @@ -117,7 +117,8 @@ class OidcHandler: for idp_id, p in self._providers.items(): try: await p.load_metadata() - await p.load_jwks() + if not p._uses_userinfo: + await p.load_jwks() except Exception as e: raise Exception( "Error while initialising OIDC provider %r" % (idp_id,) @@ -498,10 +499,6 @@ class OidcProvider: return await self._jwks.get() async def _load_jwks(self) -> JWKS: - if self._uses_userinfo: - # We're not using jwt signing, return an empty jwk set - return {"keys": []} - metadata = await self.load_metadata() # Load the JWKS using the `jwks_uri` metadata. @@ -663,7 +660,7 @@ class OidcProvider: return UserInfo(resp) - async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. Args: @@ -673,7 +670,7 @@ class OidcProvider: request. This value should match the one inside the token. Returns: - An object representing the user. + The decoded claims in the ID token. """ metadata = await self.load_metadata() claims_params = { @@ -684,9 +681,6 @@ class OidcProvider: # If we got an `access_token`, there should be an `at_hash` claim # in the `id_token` that we can check against. 
claims_params["access_token"] = token["access_token"] - claims_cls = CodeIDToken - else: - claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwt = JsonWebToken(alg_values) @@ -703,7 +697,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -713,7 +707,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -721,7 +715,8 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) claims.validate(leeway=120) # allows 2 min of clock skew - return UserInfo(claims) + + return claims async def handle_redirect_request( self, @@ -837,8 +832,22 @@ class OidcProvider: logger.debug("Successfully obtained OAuth2 token data: %r", token) - # Now that we have a token, get the userinfo, either by decoding the - # `id_token` or by fetching the `userinfo_endpoint`. + # If there is an id_token, it should be validated, regardless of the + # userinfo endpoint is used or not. + if token.get("id_token") is not None: + try: + id_token = await self._parse_id_token(token, nonce=session_data.nonce) + sid = id_token.get("sid") + except Exception as e: + logger.exception("Invalid id_token") + self._sso_handler.render_error(request, "invalid_token", str(e)) + return + else: + id_token = None + sid = None + + # Now that we have a token, get the userinfo either from the `id_token` + # claims or by fetching the `userinfo_endpoint`. if self._uses_userinfo: try: userinfo = await self._fetch_userinfo(token) @@ -846,13 +855,14 @@ class OidcProvider: logger.exception("Could not fetch userinfo") self._sso_handler.render_error(request, "fetch_error", str(e)) return + elif id_token is not None: + userinfo = UserInfo(id_token) else: - try: - userinfo = await self._parse_id_token(token, nonce=session_data.nonce) - except Exception as e: - logger.exception("Invalid id_token") - self._sso_handler.render_error(request, "invalid_token", str(e)) - return + logger.error("Missing id_token in token response") + self._sso_handler.render_error( + request, "invalid_token", "Missing id_token in token response" + ) + return # first check if we're doing a UIA if session_data.ui_auth_session_id: @@ -884,7 +894,7 @@ class OidcProvider: # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, session_data.client_redirect_url + userinfo, token, request, session_data.client_redirect_url, sid ) except MappingException as e: logger.exception("Could not map user") @@ -896,6 +906,7 @@ class OidcProvider: token: Token, request: SynapseRequest, client_redirect_url: str, + sid: Optional[str], ) -> None: """Given a UserInfo response, complete the login flow @@ -1008,6 +1019,7 @@ class OidcProvider: oidc_response_to_user_attributes, grandfather_existing_users, extra_attributes, + auth_provider_session_id=sid, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b14ddd8267..f08a516a75 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -746,6 +746,7 @@ class RegistrationHandler: is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> Tuple[str, 
str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. @@ -756,9 +757,9 @@ class RegistrationHandler: device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: Whether it should also issue a refresh token + auth_provider_session_id: The session ID received during login from the SSO IdP. Returns: Tuple of device ID, access token, access token expiration time and refresh token """ @@ -769,6 +770,8 @@ class RegistrationHandler: is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) login_counter.labels( @@ -791,6 +794,8 @@ class RegistrationHandler: is_guest: bool = False, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> LoginDict: """Helper for register_device @@ -822,7 +827,11 @@ class RegistrationHandler: refresh_token_id = None registered_device_id = await self.device_handler.check_device_registered( - user_id, device_id, initial_display_name + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if is_guest: assert access_token_expiry is None diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 49fde01cf0..65c27bc64a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -365,6 +365,7 @@ class SsoHandler: sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -415,6 +416,8 @@ class SsoHandler: extra_login_attributes: An optional dictionary of extra attributes to be provided to the client in the login response. + auth_provider_session_id: An optional session ID from the IdP. + Raises: MappingException if there was a problem mapping the response to a user. 
RedirectException: if the mapping provider needs to redirect the user @@ -490,6 +493,7 @@ class SsoHandler: client_redirect_url, extra_login_attributes, new_user=new_user, + auth_provider_session_id=auth_provider_session_id, ) async def _call_attribute_mapper( diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a8154168be..6bfb4b8d1b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -626,6 +626,7 @@ class ModuleApi: user_id: str, duration_in_ms: int = (2 * 60 * 1000), auth_provider_id: str = "", + auth_provider_session_id: Optional[str] = None, ) -> str: """Generate a login token suitable for m.login.token authentication @@ -643,6 +644,7 @@ class ModuleApi: return self._hs.get_macaroon_generator().generate_short_term_login_token( user_id, auth_provider_id, + auth_provider_session_id, duration_in_ms, ) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 0db419ea57..daacc34cea 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost, should_issue_refresh_token, + auth_provider_id, + auth_provider_session_id, ): """ Args: @@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, "should_issue_refresh_token": should_issue_refresh_token, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, } async def _handle_request(self, request, user_id): @@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] should_issue_refresh_token = content["should_issue_refresh_token"] + auth_provider_id = content["auth_provider_id"] + auth_provider_session_id = content["auth_provider_session_id"] res = await self.registration_handler.register_device_inner( user_id, @@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) return 200, res diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index a66ee4fb3d..1b23fa18cf 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet): ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. + auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: result: Dictionary of account information after successful login. 
@@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet): initial_display_name, auth_provider_id=auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=auth_provider_session_id, ) result = LoginResponse( @@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet): self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=res.auth_provider_session_id, ) async def _do_jwt_login( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9ccc66e589..d5a4a661cd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} + async def get_devices_by_auth_provider_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> List[Dict[str, Any]]: + """Retrieve the list of devices associated with a SSO IdP session ID. + + Args: + auth_provider_id: The SSO IdP ID as defined in the server config + auth_provider_session_id: The session ID within the IdP + Returns: + A list of dicts containing the device_id and the user_id of each device + """ + return await self.db_pool.simple_select_list( + table="device_auth_providers", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + retcols=("user_id", "device_id"), + desc="get_devices_by_auth_provider_session_id", + ) + @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: Optional[str] + self, + user_id: str, + device_id: str, + initial_device_display_name: Optional[str], + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. 
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
 
+            if auth_provider_id and auth_provider_session_id:
+                await self.db_pool.simple_insert(
+                    "device_auth_providers",
+                    values={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "auth_provider_id": auth_provider_id,
+                        "auth_provider_session_id": auth_provider_session_id,
+                    },
+                    desc="store_device_auth_provider",
+                )
+
             self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 keyvalues={"user_id": user_id},
             )
 
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_auth_providers",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
         await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
new file mode 100644
index 0000000000..a65bfb520d
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
@@ -0,0 +1,27 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track the auth provider used by each login as well as the session ID
+CREATE TABLE device_auth_providers (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    auth_provider_id TEXT NOT NULL,
+    auth_provider_session_id TEXT NOT NULL
+);
+
+CREATE INDEX device_auth_providers_devices
+    ON device_auth_providers (user_id, device_id);
+CREATE INDEX device_auth_providers_sessions
+    ON device_auth_providers (auth_provider_id, auth_provider_session_id);
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 72e176da75..03b8b8615c 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_short_term_login_token_gives_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", 5000
+            self.user1, "", duration_in_ms=5000
         )
         res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
         self.assertEqual(self.user1, res.user_id)
@@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", 5000
+            self.user1, "", duration_in_ms=5000
         )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
@@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def _get_macaroon(self):
         token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", 5000
+            self.user1, "", duration_in_ms=5000
         )
         return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index b625995d12..8705ff8943 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
        )
 
         # Subsequent calls should map to the same mxid.
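All of the CAS (and, below, SAML) assertions pass `auth_provider_session_id=None`: only OIDC logins carry a session identifier, via the `sid` claim of the ID token. A one-function sketch of the rule these tests encode (the helper is hypothetical, not an actual Synapse function):

    from typing import Any, Dict, Optional


    def session_id_for_login(id_token: Optional[Dict[str, Any]]) -> Optional[str]:
        # CAS and SAML responses have no OIDC-style session claim, and an OIDC
        # login that only hits the userinfo endpoint has no ID token either,
        # so all of those cases map to None.
        if not id_token:
            return None
        return id_token.get("sid")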
@@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
+            "@f=c3=b6=c3=b6:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config(
@@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "cas",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a25c89bd5b..cfe3de5266 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with patch.object(self.provider, "load_metadata", patched_load_metadata):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
-        # Return empty key set if JWKS are not used
-        self.provider._scopes = []  # not asking the openid scope
-        self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.provider.load_jwks(force=True))
-        self.http_client.get_json.assert_not_called()
-        self.assertEqual(jwks, {"keys": []})
-
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
@@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
+            expected_user_id,
+            "oidc",
+            request,
+            client_redirect_url,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.provider._scopes = []  # do not ask the "openid" scope
+        self.provider._user_profile_method = "userinfo_endpoint"
+        token = {
+            "type": "bearer",
+            "access_token": "access_token",
+        }
+        self.provider._exchange_code = simple_async_mock(return_value=token)
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
+            expected_user_id,
+            "oidc",
+            request,
+            client_redirect_url,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
         self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
+        # With an ID token, userinfo fetching and sid in the ID token
+        self.provider._user_profile_method = "userinfo_endpoint"
+        token = {
+            "type": "bearer",
+            "access_token": "access_token",
+            "id_token": "id_token",
+        }
+        id_token = {
+            "sid": "abcdefgh",
+        }
+        self.provider._parse_id_token = simple_async_mock(return_value=id_token)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        auth_handler.complete_sso_login.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
+        self.get_success(self.handler.handle_oidc_callback(request))
+
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id,
+            "oidc",
+            request,
+            client_redirect_url,
+            None,
+            new_user=False,
+            auth_provider_session_id=id_token["sid"],
+        )
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.render_error.assert_not_called()
+
         # Handle userinfo fetching error
         self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             client_redirect_url,
             {"phone": "1234567"},
             new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "oidc", ANY, ANY, None, new_user=True
+            "@test_user:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
+            "@test_user_2:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), "oidc", ANY, ANY, None, new_user=False
+            user.to_string(),
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), "oidc", ANY, ANY, None, new_user=False
+            user.to_string(),
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), "oidc", ANY, ANY, None, new_user=False
+            user.to_string(),
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
+            "@TEST_USER_2:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
+            "@test_user1:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@tester:test", "oidc", ANY, ANY, None, new_user=True
+            "@tester:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config(
@@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@tester:test", "oidc", ANY, ANY, None, new_user=True
+            "@tester:test",
+            "oidc",
+            ANY,
+            ANY,
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config(
@@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo(
 
     handler = hs.get_oidc_handler()
     provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={})
+    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
     provider._parse_id_token = simple_async_mock(return_value=userinfo)
     provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 8cfc184fef..50551aa6e3 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "saml",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "", None, new_user=False
+            "@test_user:test",
+            "saml",
+            request,
+            "",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
         # Subsequent calls should map to the same mxid.
@@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "", None, new_user=False
+            "@test_user:test",
+            "saml",
+            request,
+            "",
+            None,
+            new_user=False,
+            auth_provider_session_id=None,
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
        auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", "saml", request, "", None, new_user=True
+            "@test_user1:test",
+            "saml",
+            request,
+            "",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
+            "@test_user:test",
+            "saml",
+            request,
+            "redirect_uri",
+            None,
+            new_user=True,
+            auth_provider_session_id=None,
         )
--
cgit 1.5.1


From 2f053f3f82ca174cc1c858c75afffae51af8ce0d Mon Sep 17 00:00:00 2001
From: reivilibre
Date: Mon, 6 Dec 2021 19:11:43 +0000
Subject: Stabilise support for MSC2918 refresh tokens as they have now been
 merged into the Matrix specification. (#11435)

---
 changelog.d/11435.feature       |  1 +
 docs/sample_config.yaml         | 38 ++++++++++++++++++++++++++++++++++++++
 synapse/config/registration.py  | 38 ++++++++++++++++++++++++++++++++++++++
 synapse/rest/client/login.py    | 29 +++++++++++++----------------
 synapse/rest/client/register.py | 23 ++++++++++-------------
 tests/rest/client/test_auth.py  | 30 +++++++++++++++---------------
 6 files changed, 115 insertions(+), 44 deletions(-)
 create mode 100644 changelog.d/11435.feature

(limited to 'synapse/rest')

diff --git a/changelog.d/11435.feature b/changelog.d/11435.feature
new file mode 100644
index 0000000000..9e127fae3c
--- /dev/null
+++ b/changelog.d/11435.feature
@@ -0,0 +1 @@
+Stabilise support for [MSC2918](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) refresh tokens as they have now been merged into the Matrix specification.
\ No newline at end of file
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index ae476d19ac..6696ed5d1e 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1209,6 +1209,44 @@ oembed:
 #
 #session_lifetime: 24h
 
+# Time that an access token remains valid for, if the session is
+# using refresh tokens.
+# For more information about refresh tokens, please see the manual.
+# Note that this only applies to clients which advertise support for
+# refresh tokens.
+#
+# Note also that this is calculated at login time and refresh time:
+# changes are not applied to existing sessions until they are refreshed.
+#
+# By default, this is 5 minutes.
+#
+#refreshable_access_token_lifetime: 5m
+
+# Time that a refresh token remains valid for (provided that it is not
+# exchanged for another one first).
+# This option can be used to automatically log out inactive sessions.
+# Please see the manual for more information.
+#
+# Note also that this is calculated at login time and refresh time:
+# changes are not applied to existing sessions until they are refreshed.
+#
+# By default, this is infinite.
+#
+#refresh_token_lifetime: 24h
+
+# Time that an access token remains valid for, if the session is NOT
+# using refresh tokens.
+# Please note that not all clients support refresh tokens, so setting
+# this to a short value may be inconvenient for some users who will
+# then be logged out frequently.
+#
+# Note also that this is calculated at login time: changes are not applied
+# retrospectively to existing sessions for users that have already logged in.
+#
+# By default, this is infinite.
+#
+#nonrefreshable_access_token_lifetime: 24h
+
 # The user must provide all of the below types of 3PID when registering.
 #
 #registrations_require_3pid:
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 68a4985398..7a059c6dec 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -220,6 +220,44 @@ class RegistrationConfig(Config):
         #
         #session_lifetime: 24h
 
+        # Time that an access token remains valid for, if the session is
+        # using refresh tokens.
+        # For more information about refresh tokens, please see the manual.
+        # Note that this only applies to clients which advertise support for
+        # refresh tokens.
+        #
+        # Note also that this is calculated at login time and refresh time:
+        # changes are not applied to existing sessions until they are refreshed.
+        #
+        # By default, this is 5 minutes.
+        #
+        #refreshable_access_token_lifetime: 5m
+
+        # Time that a refresh token remains valid for (provided that it is not
+        # exchanged for another one first).
+        # This option can be used to automatically log out inactive sessions.
+        # Please see the manual for more information.
+        #
+        # Note also that this is calculated at login time and refresh time:
+        # changes are not applied to existing sessions until they are refreshed.
+        #
+        # By default, this is infinite.
+        #
+        #refresh_token_lifetime: 24h
+
+        # Time that an access token remains valid for, if the session is NOT
+        # using refresh tokens.
+        # Please note that not all clients support refresh tokens, so setting
+        # this to a short value may be inconvenient for some users who will
+        # then be logged out frequently.
+        #
+        # Note also that this is calculated at login time: changes are not applied
+        # retrospectively to existing sessions for users that have already logged in.
+        #
+        # By default, this is infinite.
+        #
+        #nonrefreshable_access_token_lifetime: 24h
+
         # The user must provide all of the below types of 3PID when registering.
         #
         #registrations_require_3pid:
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 1b23fa18cf..f9994658c4 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -72,7 +72,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "m.login.application_service"
     APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service"
-    REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
+    REFRESH_TOKEN_PARAM = "refresh_token"
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -90,7 +90,7 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2.saml2_enabled
         self.cas_enabled = hs.config.cas.cas_enabled
         self.oidc_enabled = hs.config.oidc.oidc_enabled
-        self._msc2918_enabled = (
+        self._refresh_tokens_enabled = (
             hs.config.registration.refreshable_access_token_lifetime is not None
         )
 
@@ -163,17 +163,16 @@ class LoginRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
         login_submission = parse_json_object_from_request(request)
 
-        if self._msc2918_enabled:
-            # Check if this login should also issue a refresh token, as per MSC2918
-            should_issue_refresh_token = login_submission.get(
-                "org.matrix.msc2918.refresh_token", False
-            )
-            if not isinstance(should_issue_refresh_token, bool):
-                raise SynapseError(
-                    400, "`org.matrix.msc2918.refresh_token` should be true or false."
-                )
-        else:
-            should_issue_refresh_token = False
+        # Check to see if the client requested a refresh token.
+        client_requested_refresh_token = login_submission.get(
+            LoginRestServlet.REFRESH_TOKEN_PARAM, False
+        )
+        if not isinstance(client_requested_refresh_token, bool):
+            raise SynapseError(400, "`refresh_token` should be true or false.")
+
+        should_issue_refresh_token = (
+            self._refresh_tokens_enabled and client_requested_refresh_token
+        )
 
         try:
             if login_submission["type"] in (
@@ -463,9 +462,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
 
 
 class RefreshTokenServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
-    )
+    PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),)
 
     def __init__(self, hs: "HomeServer"):
         self._auth_handler = hs.get_auth_handler()
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 11fd6cd24d..8b56c76aed 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -419,7 +419,7 @@ class RegisterRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
         self._registration_enabled = self.hs.config.registration.enable_registration
-        self._msc2918_enabled = (
+        self._refresh_tokens_enabled = (
             hs.config.registration.refreshable_access_token_lifetime is not None
         )
 
@@ -445,18 +445,15 @@ class RegisterRestServlet(RestServlet):
                 f"Do not understand membership kind: {kind}",
             )
 
-        if self._msc2918_enabled:
-            # Check if this registration should also issue a refresh token, as
-            # per MSC2918
-            should_issue_refresh_token = body.get(
-                "org.matrix.msc2918.refresh_token", False
-            )
-            if not isinstance(should_issue_refresh_token, bool):
-                raise SynapseError(
-                    400, "`org.matrix.msc2918.refresh_token` should be true or false."
-                )
-        else:
-            should_issue_refresh_token = False
+        # Check if the client wishes for this registration to issue a refresh
+        # token.
+        client_requested_refresh_tokens = body.get("refresh_token", False)
+        if not isinstance(client_requested_refresh_tokens, bool):
+            raise SynapseError(400, "`refresh_token` should be true or false.")
+
+        should_issue_refresh_token = (
+            self._refresh_tokens_enabled and client_requested_refresh_tokens
+        )
 
         # Pull out the provided username and do basic sanity checks early since
         # the auth layer will store these in sessions.
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 7239e1a1b5..aa8ad6d2e1 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -520,7 +520,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         """
         return self.make_request(
             "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
             {"refresh_token": refresh_token},
         )
 
@@ -557,7 +557,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         login_with_refresh = self.make_request(
             "POST",
             "/_matrix/client/r0/login",
-            {"org.matrix.msc2918.refresh_token": True, **body},
+            {"refresh_token": True, **body},
         )
         self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
         self.assertIn("refresh_token", login_with_refresh.json_body)
@@ -588,7 +588,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
                 "username": "test3",
                 "password": self.user_pass,
                 "auth": {"type": LoginType.DUMMY},
-                "org.matrix.msc2918.refresh_token": True,
+                "refresh_token": True,
             },
         )
         self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
@@ -603,7 +603,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "type": "m.login.password",
             "user": "test",
             "password": self.user_pass,
-            "org.matrix.msc2918.refresh_token": True,
+            "refresh_token": True,
         }
         login_response = self.make_request(
             "POST",
@@ -614,7 +614,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
 
         refresh_response = self.make_request(
             "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
            {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(refresh_response.code, 200, refresh_response.result)
@@ -641,7 +641,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "type": "m.login.password",
             "user": "test",
             "password": self.user_pass,
-            "org.matrix.msc2918.refresh_token": True,
+            "refresh_token": True,
         }
         login_response = self.make_request(
             "POST",
@@ -655,7 +655,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
 
         refresh_response = self.make_request(
             "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(refresh_response.code, 200, refresh_response.result)
@@ -761,7 +761,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "type": "m.login.password",
             "user": "test",
             "password": self.user_pass,
-            "org.matrix.msc2918.refresh_token": True,
+            "refresh_token": True,
         }
         login_response = self.make_request(
             "POST",
@@ -811,7 +811,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "type": "m.login.password",
             "user": "test",
             "password": self.user_pass,
-            "org.matrix.msc2918.refresh_token": True,
+            "refresh_token": True,
         }
         login_response = self.make_request(
             "POST",
@@ -868,7 +868,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "type": "m.login.password",
             "user": "test",
             "password": self.user_pass,
-            "org.matrix.msc2918.refresh_token": True,
+            "refresh_token": True,
         }
         login_response = self.make_request(
             "POST",
@@ -880,7 +880,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # This first refresh should work properly
         first_refresh_response = self.make_request(
             "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
@@ -890,7 +890,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # This one as well, since the token in the first one was never used
         second_refresh_response = self.make_request(
             "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
@@ -900,7 +900,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # This one should not, since the token from the first refresh is not valid anymore
         third_refresh_response = self.make_request(
             "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
             {"refresh_token": first_refresh_response.json_body["refresh_token"]},
         )
         self.assertEqual(
@@ -928,7 +928,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
        fourth_refresh_response = self.make_request(
            "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
@@ -938,7 +938,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # But refreshing from the last valid refresh token still works
         fifth_refresh_response = self.make_request(
             "POST",
-            "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+            "/_matrix/client/v1/refresh",
             {"refresh_token": second_refresh_response.json_body["refresh_token"]},
         )
         self.assertEqual(
--
cgit 1.5.1


From b1ecd19c5d19815b69e425d80f442bf2877cab76 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Tue, 7 Dec 2021 11:37:54 +0000
Subject: Fix 'delete room' admin api to work on incomplete rooms (#11523)

If, for some reason, we don't have the create event, we should still be
able to purge a room.
---
 changelog.d/11523.feature      |  1 +
 synapse/handlers/pagination.py |  3 ---
 synapse/handlers/room.py       | 21 +++++++-------------
 synapse/rest/admin/rooms.py    |  3 ---
 tests/rest/admin/test_room.py  | 42 +++++++++++++++++++++++++-----------------
 5 files changed, 33 insertions(+), 37 deletions(-)
 create mode 100644 changelog.d/11523.feature

(limited to 'synapse/rest')

diff --git a/changelog.d/11523.feature b/changelog.d/11523.feature
new file mode 100644
index 0000000000..ecac7f9db9
--- /dev/null
+++ b/changelog.d/11523.feature
@@ -0,0 +1 @@
+Extend the "delete room" admin api to work correctly on rooms which have previously been partially deleted.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index cd64142735..4f42438053 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -406,9 +406,6 @@ class PaginationHandler:
             force: set true to skip checking for joined users.
""" with await self.pagination_lock.write(room_id): - # check we know about the room - await self.store.get_room_version_id(room_id) - # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2bcdf32dcc..ead2198e14 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1535,20 +1535,13 @@ class RoomShutdownHandler: await self.store.block_room(room_id, requester_user_id) if not await self.store.get_room(room_id): - if block: - # We allow you to block an unknown room. - return { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } - else: - # But if you don't want to preventatively block another room, - # this function can't do anything useful. - raise NotFoundError( - "Cannot shut down room: unknown room id %s" % (room_id,) - ) + # if we don't know about the room, there is nothing left to do. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 669ab44a45..829e86675a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -106,9 +106,6 @@ class RoomRestV2Servlet(RestServlet): HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) - if not await self._store.get_room(room_id): - raise NotFoundError("Unknown room id %s" % (room_id,)) - delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d3858e460d..22f9aa6234 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -83,7 +83,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return 200 """ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" @@ -94,8 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_room_is_not_valid(self): """ @@ -508,27 +507,36 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - @parameterized.expand( - [ - ("DELETE", "/_synapse/admin/v2/rooms/%s"), - ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), - ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), - ] - ) - def test_room_does_not_exist(self, method: str, url: str): - """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + def test_room_does_not_exist(self): """ + Check that unknown rooms/server return 200 + This is important, as it allows incomplete vestiges of rooms to be cleared up + even if the create event/etc is missing. 
+ """ + room_id = "!unknown:test" channel = self.make_request( - method, - url % "!unknown:test", + "DELETE", + f"/_synapse/admin/v2/rooms/{room_id}", content={}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/rooms/{room_id}/delete_status", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) @parameterized.expand( [ -- cgit 1.5.1