From 10332c175c829a21b6565bc5e865bca154ffb084 Mon Sep 17 00:00:00 2001 From: David Teller Date: Mon, 18 Jan 2021 15:02:22 +0100 Subject: New API /_synapse/admin/rooms/{roomId}/context/{eventId} Signed-off-by: David Teller --- changelog.d/9150.feature | 4 ++++ synapse/handlers/room.py | 11 +++++++-- synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 53 ++++++++++++++++++++++++++++++++++++++++++ synapse/visibility.py | 26 ++++++++++++++++----- tests/rest/admin/test_room.py | 48 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 136 insertions(+), 8 deletions(-) create mode 100644 changelog.d/9150.feature diff --git a/changelog.d/9150.feature b/changelog.d/9150.feature new file mode 100644 index 0000000000..86c4fd3d72 --- /dev/null +++ b/changelog.d/9150.feature @@ -0,0 +1,4 @@ +New API /_synapse/admin/rooms/{roomId}/context/{eventId} + +This API mirrors /_matrix/client/r0/rooms/{roomId}/context/{eventId} but lets administrators +inspect rooms. Designed to annotate abuse reports with context. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee27d99135..15e90fdd4e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1008,6 +1008,7 @@ class RoomContextHandler: event_id: str, limit: int, event_filter: Optional[Filter], + use_admin_priviledge: bool = False, ) -> Optional[JsonDict]: """Retrieves events, pagination tokens and state around a given event in a room. @@ -1020,7 +1021,9 @@ class RoomContextHandler: (excluding state). event_filter: the filter to apply to the events returned (excluding the target event_id) - + use_admin_priviledge: if `True`, return all events, regardless + of whether `user` has access to them. To be used **ONLY** + from the admin API. Returns: dict, or None if the event isn't found """ @@ -1032,7 +1035,11 @@ class RoomContextHandler: def filter_evts(events): return filter_events_for_client( - self.storage, user.to_string(), events, is_peeking=is_peeking + self.storage, + user.to_string(), + events, + is_peeking=is_peeking, + use_admin_priviledge=use_admin_priviledge, ) event = await self.store.get_event( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 57e0a10837..cb24998b47 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -42,6 +42,7 @@ from synapse.rest.admin.rooms import ( JoinRoomAliasServlet, ListRoomRestServlet, MakeRoomAdminRestServlet, + RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, ShutdownRoomRestServlet, @@ -236,6 +237,7 @@ def register_servlets(hs, http_server): MakeRoomAdminRestServlet(hs).register(http_server) ShadowBanRestServlet(hs).register(http_server) ForwardExtremitiesRestServlet(hs).register(http_server) + RoomEventContextServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f14915d47e..6539655289 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -566,3 +566,56 @@ class ForwardExtremitiesRestServlet(RestServlet): extremities = await self.store.get_forward_extremities_for_room(room_id) return 200, {"count": len(extremities), "results": extremities} + +class RoomEventContextServlet(RestServlet): + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/context/(?P[^/]*)$") + + def __init__(self, hs): + super().__init__() + self.clock = hs.get_clock() + self.room_context_handler = hs.get_room_context_handler() + self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + limit = parse_integer(request, "limit", default=10) + + # picking the API shape for symmetry with /messages + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] + else: + event_filter = None + + results = await self.room_context_handler.get_event_context( + requester.user, + room_id, + event_id, + limit, + event_filter, + use_admin_priviledge=True, + ) + + if not results: + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + time_now = self.clock.time_msec() + results["events_before"] = await self._event_serializer.serialize_events( + results["events_before"], time_now + ) + results["event"] = await self._event_serializer.serialize_event( + results["event"], time_now + ) + results["events_after"] = await self._event_serializer.serialize_events( + results["events_after"], time_now + ) + results["state"] = await self._event_serializer.serialize_events( + results["state"], time_now + ) + + return 200, results diff --git a/synapse/visibility.py b/synapse/visibility.py index ec50e7e977..80c304562d 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -53,6 +53,7 @@ async def filter_events_for_client( is_peeking=False, always_include_ids=frozenset(), filter_send_to_client=True, + use_admin_priviledge=False, ): """ Check which events a user is allowed to see. If the user can see the event but its @@ -71,6 +72,9 @@ async def filter_events_for_client( filter_send_to_client (bool): Whether we're checking an event that's going to be sent to a client. This might not always be the case since this function can also be called to check whether a user can see the state at a given point. + use_admin_priviledge: if `True`, return all events, regardless + of whether `user` has access to them. To be used **ONLY** + from the admin API. Returns: list[synapse.events.EventBase] @@ -79,15 +83,23 @@ async def filter_events_for_client( # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] - types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + types = None + if use_admin_priviledge: + # Administrators can access all events. + types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) + else: + types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id - ) + ignore_dict_content = None + if not use_admin_priviledge: + ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id + ) ignore_list = frozenset() if ignore_dict_content: @@ -183,10 +195,12 @@ async def filter_events_for_client( if old_priority < new_priority: visibility = prev_visibility + membership = None + if use_admin_priviledge: + membership = Membership.JOIN # likewise, if the event is the user's own membership event, use # the 'most joined' membership - membership = None - if event.type == EventTypes.Member and event.state_key == user_id: + elif event.type == EventTypes.Member and event.state_key == user_id: membership = event.content.get("membership", None) if membership not in MEMBERSHIP_PRIORITY: membership = "leave" diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a0f32c5512..7e89eb4793 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1430,6 +1430,54 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + def test_context(self): + """ + Test that, as admin, we can find the context of an event without having joined the room. + """ + + # Create a room. We're not part of it. + user_id = self.register_user("test", "test") + user_tok = self.login("test", "test") + room_id = self.helper.create_room_as(user_id, tok=user_tok) + + # Populate the room with events. + events = [] + for i in range(30): + events.append( + self.helper.send_event( + room_id, "com.example.test", content={"index": i}, tok=user_tok + ) + ) + + # Now let's fetch the context for this room. + midway = (len(events) - 1) // 2 + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/context/%s" + % (room_id, events[midway]["event_id"]), + access_token=self.admin_user_tok, + ) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals( + channel.json_body["event"]["event_id"], events[midway]["event_id"] + ) + + for i, found_event in enumerate(channel.json_body["events_before"]): + for j, posted_event in enumerate(events): + if found_event["event_id"] == posted_event["event_id"]: + self.assertTrue(j < midway) + break + else: + self.fail("Event %s from events_before not found" % j) + + for i, found_event in enumerate(channel.json_body["events_after"]): + for j, posted_event in enumerate(events): + if found_event["event_id"] == posted_event["event_id"]: + self.assertTrue(j > midway) + break + else: + self.fail("Event %s from events_after not found" % j) + class MakeRoomAdminTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From fe52dae6bd8508bfb79ba5e534a384a9e5e96b5d Mon Sep 17 00:00:00 2001 From: David Teller Date: Fri, 22 Jan 2021 14:41:20 +0100 Subject: FIXUP: Documenting /_synapse/admin/v1/rooms//context/ --- docs/admin_api/rooms.md | 119 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index f34cec1ff7..6d737c5be1 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -10,6 +10,7 @@ * [Undoing room shutdowns](#undoing-room-shutdowns) - [Make Room Admin API](#make-room-admin-api) - [Forward Extremities Admin API](#forward-extremities-admin-api) +- [Event Context API](#event-context-api) # List Room API @@ -564,3 +565,121 @@ that were deleted. "deleted": 1 } ``` + +# Event Context API + +This API lets a client find the context of an event. This is designed primarily to investigate abuse reports. + +``` +GET /_synapse/admin/v1/rooms//context/ +``` + +This API mimmicks [GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-context-eventid). Please refer to the link for all details on parameters and reseponse. + +Example response: + +```json +{ + "end": "t29-57_2_0_2", + "events_after": [ + { + "content": { + "body": "This is an example text message", + "msgtype": "m.text", + "format": "org.matrix.custom.html", + "formatted_body": "This is an example text message" + }, + "type": "m.room.message", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + } + } + ], + "event": { + "content": { + "body": "filename.jpg", + "info": { + "h": 398, + "w": 394, + "mimetype": "image/jpeg", + "size": 31037 + }, + "url": "mxc://example.org/JWEIFJgwEIhweiWJE", + "msgtype": "m.image" + }, + "type": "m.room.message", + "event_id": "$f3h4d129462ha:example.com", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + } + }, + "events_before": [ + { + "content": { + "body": "something-important.doc", + "filename": "something-important.doc", + "info": { + "mimetype": "application/msword", + "size": 46144 + }, + "msgtype": "m.file", + "url": "mxc://example.org/FHyPlCeYUSFFxlgbQYZmoEoe" + }, + "type": "m.room.message", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + } + } + ], + "start": "t27-54_2_0_2", + "state": [ + { + "content": { + "creator": "@example:example.org", + "room_version": "1", + "m.federate": true, + "predecessor": { + "event_id": "$something:example.org", + "room_id": "!oldroom:example.org" + } + }, + "type": "m.room.create", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + }, + "state_key": "" + }, + { + "content": { + "membership": "join", + "avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF", + "displayname": "Alice Margatroid" + }, + "type": "m.room.member", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + }, + "state_key": "@alice:example.org" + } + ] +} +``` -- cgit 1.5.1 From de7f049527c64470a16b2633862a1f1b8f0da9c2 Mon Sep 17 00:00:00 2001 From: David Teller Date: Mon, 25 Jan 2021 18:02:35 +0100 Subject: FIXUP: Don't filter events at all for admin/v1/rooms/.../context/... --- synapse/handlers/room.py | 10 ++++------ synapse/visibility.py | 25 ++++++------------------- 2 files changed, 10 insertions(+), 25 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 15e90fdd4e..b02056dc63 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -52,7 +52,7 @@ from synapse.types import ( create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, maybe_awaitable from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client @@ -1034,12 +1034,10 @@ class RoomContextHandler: is_peeking = user.to_string() not in users def filter_evts(events): + if use_admin_priviledge: + return maybe_awaitable(events) return filter_events_for_client( - self.storage, - user.to_string(), - events, - is_peeking=is_peeking, - use_admin_priviledge=use_admin_priviledge, + self.storage, user.to_string(), events, is_peeking=is_peeking ) event = await self.store.get_event( diff --git a/synapse/visibility.py b/synapse/visibility.py index 80c304562d..2109e2f37f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -53,7 +53,6 @@ async def filter_events_for_client( is_peeking=False, always_include_ids=frozenset(), filter_send_to_client=True, - use_admin_priviledge=False, ): """ Check which events a user is allowed to see. If the user can see the event but its @@ -72,9 +71,6 @@ async def filter_events_for_client( filter_send_to_client (bool): Whether we're checking an event that's going to be sent to a client. This might not always be the case since this function can also be called to check whether a user can see the state at a given point. - use_admin_priviledge: if `True`, return all events, regardless - of whether `user` has access to them. To be used **ONLY** - from the admin API. Returns: list[synapse.events.EventBase] @@ -83,23 +79,16 @@ async def filter_events_for_client( # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] - types = None - if use_admin_priviledge: - # Administrators can access all events. - types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) - else: - types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = None - if not use_admin_priviledge: - ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id - ) + ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id + ) ignore_list = frozenset() if ignore_dict_content: @@ -195,12 +184,10 @@ async def filter_events_for_client( if old_priority < new_priority: visibility = prev_visibility - membership = None - if use_admin_priviledge: - membership = Membership.JOIN # likewise, if the event is the user's own membership event, use # the 'most joined' membership - elif event.type == EventTypes.Member and event.state_key == user_id: + membership = None + if event.type == EventTypes.Member and event.state_key == user_id: membership = event.content.get("membership", None) if membership not in MEMBERSHIP_PRIORITY: membership = "leave" -- cgit 1.5.1 From b859919acc3ad6c61ba26a3c9b1e36c75a30c3fe Mon Sep 17 00:00:00 2001 From: David Teller Date: Thu, 28 Jan 2021 12:18:07 +0100 Subject: FIXUP: Now testing that the user is admin! --- changelog.d/9150.feature | 5 +---- synapse/rest/admin/rooms.py | 3 ++- tests/rest/admin/test_room.py | 36 +++++++++++++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/changelog.d/9150.feature b/changelog.d/9150.feature index 86c4fd3d72..48a8148dee 100644 --- a/changelog.d/9150.feature +++ b/changelog.d/9150.feature @@ -1,4 +1 @@ -New API /_synapse/admin/rooms/{roomId}/context/{eventId} - -This API mirrors /_matrix/client/r0/rooms/{roomId}/context/{eventId} but lets administrators -inspect rooms. Designed to annotate abuse reports with context. +New API /_synapse/admin/rooms/{roomId}/context/{eventId}. diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6539655289..4393197549 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -578,7 +578,8 @@ class RoomEventContextServlet(RestServlet): self.auth = hs.get_auth() async def on_GET(self, request, room_id, event_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=False) + await assert_user_is_admin(self.auth, requester.user) limit = parse_integer(request, "limit", default=10) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 7e89eb4793..fd201993d3 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1430,7 +1430,41 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_context(self): + def test_context_as_non_admin(self): + """ + Test that, without being admin, one cannot use the context admin API + """ + # Create a room. + user_id = self.register_user("test", "test") + user_tok = self.login("test", "test") + + self.register_user("test_2", "test") + user_tok_2 = self.login("test_2", "test") + + room_id = self.helper.create_room_as(user_id, tok=user_tok) + + # Populate the room with events. + events = [] + for i in range(30): + events.append( + self.helper.send_event( + room_id, "com.example.test", content={"index": i}, tok=user_tok + ) + ) + + # Now attempt to find the context using the admin API without being admin. + midway = (len(events) - 1) // 2 + for tok in [user_tok, user_tok_2]: + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/context/%s" + % (room_id, events[midway]["event_id"]), + access_token=tok, + ) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_context_as_admin(self): """ Test that, as admin, we can find the context of an event without having joined the room. """ -- cgit 1.5.1 From a7648696231f3df7ec1f324a15dc308fbee2a985 Mon Sep 17 00:00:00 2001 From: David Teller Date: Thu, 28 Jan 2021 12:19:39 +0100 Subject: FIXUP: Doc --- synapse/rest/admin/rooms.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 4393197549..50df60e029 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -568,6 +568,12 @@ class ForwardExtremitiesRestServlet(RestServlet): return 200, {"count": len(extremities), "results": extremities} class RoomEventContextServlet(RestServlet): + """ + Provide the context for an event. + This API is designed to be used when system administrators wish to look at + an abuse report and understand what happened during and immediately prior + to this event. + """ PATTERNS = admin_patterns("/rooms/(?P[^/]*)/context/(?P[^/]*)$") def __init__(self, hs): -- cgit 1.5.1 From b755f60ce2fb651356d8d588dcd7864b022023cd Mon Sep 17 00:00:00 2001 From: David Teller Date: Thu, 28 Jan 2021 12:23:19 +0100 Subject: FIXUP: Removing awaitable --- synapse/handlers/room.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b02056dc63..c103488076 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1033,10 +1033,10 @@ class RoomContextHandler: users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users - def filter_evts(events): + async def filter_evts(events): if use_admin_priviledge: - return maybe_awaitable(events) - return filter_events_for_client( + return events + return await filter_events_for_client( self.storage, user.to_string(), events, is_peeking=is_peeking ) -- cgit 1.5.1 From 93f84e037349cb3efddd8df5adf22512530a295c Mon Sep 17 00:00:00 2001 From: David Teller Date: Thu, 28 Jan 2021 12:27:30 +0100 Subject: FIXUP: Making get_event_context a bit more paranoid --- synapse/handlers/room.py | 10 ++++++++-- synapse/rest/admin/rooms.py | 2 +- synapse/rest/client/v1/room.py | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index c103488076..e039cea024 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,6 +38,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -997,13 +998,14 @@ class RoomCreationHandler(BaseHandler): class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.auth = hs.get_auth() self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state async def get_event_context( self, - user: UserID, + requester: Requester, room_id: str, event_id: str, limit: int, @@ -1014,7 +1016,7 @@ class RoomContextHandler: in a room. Args: - user + requester room_id event_id limit: The maximum number of events to return in total @@ -1027,6 +1029,10 @@ class RoomContextHandler: Returns: dict, or None if the event isn't found """ + user = requester.user + if use_admin_priviledge: + await assert_user_is_admin(self.auth, requester.user) + before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 50df60e029..641e325581 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -600,7 +600,7 @@ class RoomEventContextServlet(RestServlet): event_filter = None results = await self.room_context_handler.get_event_context( - requester.user, + requester, room_id, event_id, limit, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f95627ee61..90fd98c53e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -650,7 +650,7 @@ class RoomEventContextServlet(RestServlet): event_filter = None results = await self.room_context_handler.get_event_context( - requester.user, room_id, event_id, limit, event_filter + requester, room_id, event_id, limit, event_filter ) if not results: -- cgit 1.5.1 From 31d072aea0a37ad5408995359b89080b5280f57d Mon Sep 17 00:00:00 2001 From: David Teller Date: Thu, 28 Jan 2021 16:53:40 +0100 Subject: FIXUP: linter --- synapse/handlers/room.py | 2 +- synapse/rest/admin/rooms.py | 5 +++++ tests/rest/admin/test_room.py | 4 +++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e039cea024..99fd063084 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -53,7 +53,7 @@ from synapse.types import ( create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer, maybe_awaitable +from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 641e325581..c673ff07e7 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -15,9 +15,11 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple +from urllib import parse as urlparse from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.filtering import Filter from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -33,6 +35,7 @@ from synapse.rest.admin._base import ( ) from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -567,6 +570,7 @@ class ForwardExtremitiesRestServlet(RestServlet): extremities = await self.store.get_forward_extremities_for_room(room_id) return 200, {"count": len(extremities), "results": extremities} + class RoomEventContextServlet(RestServlet): """ Provide the context for an event. @@ -574,6 +578,7 @@ class RoomEventContextServlet(RestServlet): an abuse report and understand what happened during and immediately prior to this event. """ + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/context/(?P[^/]*)$") def __init__(self, hs): diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index fd201993d3..f4f5569a89 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1461,7 +1461,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals( + 403, int(channel.result["code"]), msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self): -- cgit 1.5.1 From d882fbca388692f475673007df4f7ac394446b46 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Fri, 5 Feb 2021 21:39:19 +0100 Subject: Update type hints for Cursor to match PEP 249. (#9299) --- changelog.d/9299.misc | 1 + synapse/storage/database.py | 14 +++++++++----- synapse/storage/prepare_database.py | 4 ++-- synapse/storage/types.py | 37 +++++++++++++++++++++++++++++-------- synapse/storage/util/sequence.py | 8 ++++++-- 5 files changed, 47 insertions(+), 17 deletions(-) create mode 100644 changelog.d/9299.misc diff --git a/changelog.d/9299.misc b/changelog.d/9299.misc new file mode 100644 index 0000000000..c883a677ed --- /dev/null +++ b/changelog.d/9299.misc @@ -0,0 +1 @@ +Update the `Cursor` type hints to better match PEP 249. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d2ba4bd2fc..ae4bf1a54f 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -158,8 +158,8 @@ class LoggingDatabaseConnection: def commit(self) -> None: self.conn.commit() - def rollback(self, *args, **kwargs) -> None: - self.conn.rollback(*args, **kwargs) + def rollback(self) -> None: + self.conn.rollback() def __enter__(self) -> "Connection": self.conn.__enter__() @@ -244,12 +244,15 @@ class LoggingTransaction: assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) + def fetchone(self) -> Optional[Tuple]: + return self.txn.fetchone() + + def fetchmany(self, size: Optional[int] = None) -> List[Tuple]: + return self.txn.fetchmany(size=size) + def fetchall(self) -> List[Tuple]: return self.txn.fetchall() - def fetchone(self) -> Tuple: - return self.txn.fetchone() - def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() @@ -754,6 +757,7 @@ class DatabasePool: Returns: A list of dicts where the key is the column header. """ + assert cursor.description is not None, "cursor.description was None" col_headers = [intern(str(column[0])) for column in cursor.description] results = [dict(zip(col_headers, row)) for row in cursor] return results diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 566ea19bae..28bb2eb662 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -619,9 +619,9 @@ def _get_or_create_schema_state( txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() - current_version = int(row[0]) if row else None - if current_version: + if row is not None: + current_version = int(row[0]) txn.execute( "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 9cadcba18f..17291c9d5e 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union from typing_extensions import Protocol @@ -20,23 +20,44 @@ from typing_extensions import Protocol Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ +_Parameters = Union[Sequence[Any], Mapping[str, Any]] + class Cursor(Protocol): - def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + def execute(self, sql: str, parameters: _Parameters = ...) -> Any: ... - def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any: ... - def fetchall(self) -> List[Tuple]: + def fetchone(self) -> Optional[Tuple]: + ... + + def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: ... - def fetchone(self) -> Tuple: + def fetchall(self) -> List[Tuple]: ... @property - def description(self) -> Any: - return None + def description( + self, + ) -> Optional[ + Sequence[ + # Note that this is an approximate typing based on sqlite3 and other + # drivers, and may not be entirely accurate. + Tuple[ + str, + Optional[Any], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + ] + ] + ]: + ... @property def rowcount(self) -> int: @@ -59,7 +80,7 @@ class Connection(Protocol): def commit(self) -> None: ... - def rollback(self, *args, **kwargs) -> None: + def rollback(self) -> None: ... def __enter__(self) -> "Connection": diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 0ec4dc2918..e2b316a218 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator): def get_next_id_txn(self, txn: Cursor) -> int: txn.execute("SELECT nextval(?)", (self._sequence_name,)) - return txn.fetchone()[0] + fetch_res = txn.fetchone() + assert fetch_res is not None + return fetch_res[0] def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: txn.execute( @@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator): txn.execute( "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} ) - last_value, is_called = txn.fetchone() + fetch_res = txn.fetchone() + assert fetch_res is not None + last_value, is_called = fetch_res # If we have an associated stream check the stream_positions table. max_in_stream_positions = None -- cgit 1.5.1 From 0963d39ea66be95d44f8f4ae29fa6f33fc6740b4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 8 Feb 2021 12:33:30 -0500 Subject: Handle additional errors when previewing URLs. (#9333) * Handle the case of lxml not finding a document tree. * Parse the document encoding from the XML tag. --- changelog.d/9333.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 71 +++++++++++++----- tests/test_preview.py | 103 +++++++++++++++++++++++--- 3 files changed, 145 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9333.bugfix diff --git a/changelog.d/9333.bugfix b/changelog.d/9333.bugfix new file mode 100644 index 0000000000..c34ba378c5 --- /dev/null +++ b/changelog.d/9333.bugfix @@ -0,0 +1 @@ +Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.". diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index bf3be653aa..ae53b1d23f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -58,7 +58,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) +_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I) +_xml_encoding_match = re.compile( + br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I +) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) OG_TAG_NAME_MAXLEN = 50 @@ -300,24 +303,7 @@ class PreviewUrlResource(DirectServeJsonResource): with open(media_info["filename"], "rb") as file: body = file.read() - encoding = None - - # Let's try and figure out if it has an encoding set in a meta tag. - # Limit it to the first 1kb, since it ought to be in the meta tags - # at the top. - match = _charset_match.search(body[:1000]) - - # If we find a match, it should take precedence over the - # Content-Type header, so set it here. - if match: - encoding = match.group(1).decode("ascii") - - # If we don't find a match, we'll look at the HTTP Content-Type, and - # if that doesn't exist, we'll fall back to UTF-8. - if not encoding: - content_match = _content_type_match.match(media_info["media_type"]) - encoding = content_match.group(1) if content_match else "utf-8" - + encoding = get_html_media_encoding(body, media_info["media_type"]) og = decode_and_calc_og(body, media_info["uri"], encoding) # pre-cache the image for posterity @@ -689,6 +675,48 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") +def get_html_media_encoding(body: bytes, content_type: str) -> str: + """ + Get the encoding of the body based on the (presumably) HTML body or media_type. + + The precedence used for finding a character encoding is: + + 1. meta tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to UTF-8. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Let's try and figure out if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + return match.group(1).decode("ascii") + + # TODO Support + + # If we didn't find a match, see if it an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + return match.group(1).decode("ascii") + + # If we don't find a match, we'll look at the HTTP Content-Type, and + # if that doesn't exist, we'll fall back to UTF-8. + content_match = _content_type_match.match(content_type) + if content_match: + return content_match.group(1) + + return "utf-8" + + def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None ) -> Dict[str, Optional[str]]: @@ -725,6 +753,11 @@ def decode_and_calc_og( def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: # Attempt to parse the body. If this fails, log and return no metadata. tree = etree.fromstring(body_attempt, parser) + + # The data was successfully parsed, but no tree was found. + if tree is None: + return {} + return _calc_og(tree, media_uri) # Attempt to parse the body. If this fails, log and return no metadata. diff --git a/tests/test_preview.py b/tests/test_preview.py index 0c6cbbd921..ea83299918 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -15,6 +15,7 @@ from synapse.rest.media.v1.preview_url_resource import ( decode_and_calc_og, + get_html_media_encoding, summarize_paragraphs, ) @@ -26,7 +27,7 @@ except ImportError: lxml = None -class PreviewTestCase(unittest.TestCase): +class SummarizeTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" @@ -144,12 +145,12 @@ class PreviewTestCase(unittest.TestCase): ) -class PreviewUrlTestCase(unittest.TestCase): +class CalcOgTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" def test_simple(self): - html = """ + html = b""" Foo @@ -163,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment(self): - html = """ + html = b""" Foo @@ -178,7 +179,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment2(self): - html = """ + html = b""" Foo @@ -202,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase): ) def test_script(self): - html = """ + html = b""" Foo @@ -217,7 +218,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_missing_title(self): - html = """ + html = b""" Some text. @@ -230,7 +231,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_h1_as_title(self): - html = """ + html = b""" @@ -244,7 +245,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) def test_missing_title_and_broken_h1(self): - html = """ + html = b"""

@@ -258,13 +259,20 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_empty(self): - html = "" + """Test a body with no data in it.""" + html = b"" + og = decode_and_calc_og(html, "http://example.com/test.html") + self.assertEqual(og, {}) + + def test_no_tree(self): + """A valid body with no tree in it.""" + html = b"\x00" og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEqual(og, {}) def test_invalid_encoding(self): """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" - html = """ + html = b""" Foo @@ -290,3 +298,76 @@ class PreviewUrlTestCase(unittest.TestCase): """ og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) + + +class MediaEncodingTestCase(unittest.TestCase): + def test_meta_charset(self): + """A character encoding is found via the meta tag.""" + encoding = get_html_media_encoding( + b""" + + + + + """, + "text/html", + ) + self.assertEqual(encoding, "ascii") + + # A less well-formed version. + encoding = get_html_media_encoding( + b""" + + < meta charset = ascii> + + + """, + "text/html", + ) + self.assertEqual(encoding, "ascii") + + def test_xml_encoding(self): + """A character encoding is found via the meta tag.""" + encoding = get_html_media_encoding( + b""" + + + + """, + "text/html", + ) + self.assertEqual(encoding, "ascii") + + def test_meta_xml_encoding(self): + """Meta tags take precedence over XML encoding.""" + encoding = get_html_media_encoding( + b""" + + + + + + """, + "text/html", + ) + self.assertEqual(encoding, "UTF-16") + + def test_content_type(self): + """A character encoding is found via the Content-Type header.""" + # Test a few variations of the header. + headers = ( + 'text/html; charset="ascii";', + "text/html;charset=ascii;", + 'text/html; charset="ascii"', + "text/html; charset=ascii", + 'text/html; charset="ascii;', + 'text/html; charset=ascii";', + ) + for header in headers: + encoding = get_html_media_encoding(b"", header) + self.assertEqual(encoding, "ascii") + + def test_fallback(self): + """A character encoding cannot be found in the body or header.""" + encoding = get_html_media_encoding(b"", "text/html") + self.assertEqual(encoding, "utf-8") -- cgit 1.5.1 From 3f58fc848d0002de4605bed91603a1f9f245d128 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 8 Feb 2021 13:59:54 -0500 Subject: Type hints and validation improvements. (#9321) * Adds type hints to the groups servlet and stringutils code. * Assert the maximum length of some input values for spec compliance. --- changelog.d/9321.bugfix | 1 + synapse/groups/groups_server.py | 25 ++++- synapse/rest/client/v2_alpha/groups.py | 179 ++++++++++++++++++++----------- synapse/rest/client/v2_alpha/register.py | 2 + synapse/server.py | 16 ++- synapse/util/stringutils.py | 33 +++--- 6 files changed, 177 insertions(+), 79 deletions(-) create mode 100644 changelog.d/9321.bugfix diff --git a/changelog.d/9321.bugfix b/changelog.d/9321.bugfix new file mode 100644 index 0000000000..52eed80969 --- /dev/null +++ b/changelog.d/9321.bugfix @@ -0,0 +1 @@ +Assert a maximum length for the `client_secret` parameter for spec compliance. diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 0d042cbfac..76bf52ea23 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -18,6 +18,7 @@ import logging from synapse.api.errors import Codes, SynapseError +from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute @@ -32,6 +33,11 @@ logger = logging.getLogger(__name__) # TODO: Flairs +# Note that the maximum lengths are somewhat arbitrary. +MAX_SHORT_DESC_LEN = 1000 +MAX_LONG_DESC_LEN = 10000 + + class GroupsServerWorkerHandler: def __init__(self, hs): self.hs = hs @@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler): ) profile = {} - for keyname in ("name", "avatar_url", "short_description", "long_description"): + for keyname, max_length in ( + ("name", MAX_DISPLAYNAME_LEN), + ("avatar_url", MAX_AVATAR_URL_LEN), + ("short_description", MAX_SHORT_DESC_LEN), + ("long_description", MAX_LONG_DESC_LEN), + ): if keyname in content: value = content[keyname] if not isinstance(value, str): - raise SynapseError(400, "%r value is not a string" % (keyname,)) + raise SynapseError( + 400, + "%r value is not a string" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) + if len(value) > max_length: + raise SynapseError( + 400, + "Invalid %s parameter" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) profile[keyname] = value await self.store.update_group_profile(group_id, profile) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 5b5da71815..4fe712b30c 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,13 +16,24 @@ import logging from functools import wraps +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import GroupID +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.types import GroupID, JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -33,7 +44,7 @@ def _validate_group_id(f): """ @wraps(f) - def wrapper(self, request, group_id, *args, **kwargs): + def wrapper(self, request: Request, group_id: str, *args, **kwargs): if not GroupID.is_valid(group_id): raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) @@ -48,14 +59,14 @@ class GroupServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -66,11 +77,15 @@ class GroupServlet(RestServlet): return 200, group_description @_validate_group_id - async def on_POST(self, request, group_id): + async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert_params_in_dict( + content, ("name", "avatar_url", "short_description", "long_description") + ) + assert isinstance(self.groups_handler, GroupsLocalHandler) await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet): "/rooms/(?P[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, category_id, room_id): + async def on_PUT( + self, request: Request, group_id: str, category_id: str, room_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, @@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, category_id, room_id): + async def on_DELETE( + self, request: Request, group_id: str, category_id: str, room_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet): "/groups/(?P[^/]*)/categories/(?P[^/]+)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id, category_id): + async def on_GET( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet): return 200, category @_validate_group_id - async def on_PUT(self, request, group_id, category_id): + async def on_PUT( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) @@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, category_id): + async def on_DELETE( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id, role_id): + async def on_GET( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet): return 200, category @_validate_group_id - async def on_PUT(self, request, group_id, role_id): + async def on_PUT( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) @@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, role_id): + async def on_DELETE( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet): "/users/(?P[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, role_id, user_id): + async def on_PUT( + self, request: Request, group_id: str, role_id: str, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, @@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, role_id, user_id): + async def on_DELETE( + self, request: Request, group_id: str, role_id: str, user_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet): PATTERNS = client_patterns("/create_group$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet): "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, room_id): + async def on_PUT( + self, request: Request, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) @@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet): return 200, result @_validate_group_id - async def on_DELETE(self, request, group_id, room_id): + async def on_DELETE( + self, request: Request, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet): "/config/(?P[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, room_id, config_key): + async def on_PUT( + self, request: Request, group_id: str, room_id: str, config_key: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet): "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet): self.is_mine_id = hs.is_mine_id @_validate_group_id - async def on_PUT(self, request, group_id, user_id): + async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet): "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, user_id): + async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request, user_id): + async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) result = await self.groups_handler.get_publicised_groups_for_user(user_id) @@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/joined_groups$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet): return 200, result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): GroupServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 10e1891174..e3d322f2ac 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -193,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): body, ["client_secret", "country", "phone_number", "send_attempt"] ) client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) country = body["country"] phone_number = body["phone_number"] send_attempt = body["send_attempt"] @@ -293,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet): sid = parse_string(request, "sid", required=True) client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) token = parse_string(request, "token", required=True) # Attempt to validate a 3PID session diff --git a/synapse/server.py b/synapse/server.py index 9bdd3177d7..6b3892e3cd 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,17 @@ import abc import functools import logging import os -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + TypeVar, + Union, + cast, +) import twisted.internet.base import twisted.internet.tcp @@ -588,7 +598,9 @@ class HomeServer(metaclass=abc.ABCMeta): return UserDirectoryHandler(self) @cache_in_self - def get_groups_local_handler(self): + def get_groups_local_handler( + self, + ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]: if self.config.worker_app: return GroupsLocalWorkerHandler(self) else: diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index f8038bf861..9ce7873ab5 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken -client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$") # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically @@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") rand = random.SystemRandom() -def random_string(length): +def random_string(length: int) -> str: return "".join(rand.choice(string.ascii_letters) for _ in range(length)) -def random_string_with_symbols(length): +def random_string_with_symbols(length: int) -> str: return "".join(rand.choice(_string_with_symbols) for _ in range(length)) -def is_ascii(s): - if isinstance(s, bytes): - try: - s.decode("ascii").encode("ascii") - except UnicodeDecodeError: - return False - except UnicodeEncodeError: - return False - return True +def is_ascii(s: bytes) -> bool: + try: + s.decode("ascii").encode("ascii") + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False + return True -def assert_valid_client_secret(client_secret): - """Validate that a given string matches the client_secret regex defined by the spec""" - if client_secret_regex.match(client_secret) is None: +def assert_valid_client_secret(client_secret: str) -> None: + """Validate that a given string matches the client_secret defined by the spec""" + if ( + len(client_secret) <= 0 + or len(client_secret) > 255 + or CLIENT_SECRET_REGEX.match(client_secret) is None + ): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) -- cgit 1.5.1