From 3bbe532abb7bfc41467597731ac1a18c0331f539 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 08:02:11 -0400 Subject: Add an API for listing threads in a room. (#13394) Implement the /threads endpoint from MSC3856. This is currently unstable and behind an experimental configuration flag. It includes a background update to backfill data, results from the /threads endpoint will be partial until that finishes. --- tests/rest/client/test_relations.py | 151 ++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) (limited to 'tests/rest') diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 988cdb746d..d595295e2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations[RelationTypes.THREAD]["latest_event"]["event_id"], related_event_id, ) + + +class ThreadsTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_threads(self) -> None: + """Create threads and ensure the ordering is due to their latest event.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1]) + + # Update the first thread, the ordering should swap. + self._send_relation(RelationTypes.THREAD, "m.room.test") + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1, thread_2]) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_pagination(self) -> None: + """Create threads and paginate through them.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2]) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. 
+ next_batch = channel.json_body.get("next_batch") + self.assertIsInstance(next_batch, str, channel.json_body) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) + + self.assertNotIn("next_batch", channel.json_body, channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_include(self) -> None: + """Filtering threads to all or participated in should work.""" + # Thread 1 has the user as the root event. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 has the user replying. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Thread 3 has the user not participating in. + res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token) + thread_3 = res["event_id"] + self._send_relation( + RelationTypes.THREAD, + "m.room.test", + access_token=self.user2_token, + parent_id=thread_3, + ) + + # All threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + thread_roots, [thread_3, thread_2, thread_1], channel.json_body + ) + + # Only participated threads. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_ignored_user(self) -> None: + """Events from ignored users should be ignored.""" + # Thread 1 has a reply from an ignored user. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 is created by an ignored user. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Ignore user2. + self.get_success( + self.store.add_account_data_for_user( + self.user_id, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": {self.user2_id: {}}}, + ) + ) + + # Only thread 1 is returned. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) -- cgit 1.4.1 From c3e4edb4d6ba33383bc056e3ff22b2d034d3e248 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 07:16:50 -0400 Subject: Stabilize the threads API. 
(#14175) Stabilize the threads API (MSC3856) by supporting (only) the v1 path for the endpoint. This also marks the API as safe for workers since it is a read-only API. --- changelog.d/13394.feature | 2 +- changelog.d/14175.feature | 1 + docker/configure_workers_and_start.py | 1 + docs/workers.md | 1 + synapse/config/experimental.py | 3 --- synapse/rest/client/relations.py | 9 ++----- tests/rest/client/test_relations.py | 47 +++++++++++++++++++++-------------- 7 files changed, 35 insertions(+), 29 deletions(-) create mode 100644 changelog.d/14175.feature (limited to 'tests/rest') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature index 68de079cf3..df3ce45a76 100644 --- a/changelog.d/13394.feature +++ b/changelog.d/13394.feature @@ -1 +1 @@ -Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/changelog.d/14175.feature b/changelog.d/14175.feature new file mode 100644 index 0000000000..df3ce45a76 --- /dev/null +++ b/changelog.d/14175.feature @@ -0,0 +1 @@ +Support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 8e7f605b24..d708237f69 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -118,6 +118,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$", "^/_matrix/client/v1/rooms/.*/hierarchy$", "^/_matrix/client/(v1|unstable)/rooms/.*/relations/", + "^/_matrix/client/v1/rooms/.*/threads$", "^/_matrix/client/(api/v1|r0|v3|unstable)/login$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$", "^/_matrix/client/(api/v1|r0|v3|unstable)/account/whoami$", diff --git a/docs/workers.md b/docs/workers.md index e8d6cbaf8b..c27b3f8bd5 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -204,6 +204,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/v1/rooms/.*/hierarchy$ ^/_matrix/client/(v1|unstable)/rooms/.*/relations/ + ^/_matrix/client/v1/rooms/.*/threads$ ^/_matrix/client/unstable/org.matrix.msc2716/rooms/.*/batch_send$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1860006536..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,9 +101,6 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) - # MSC3856: Threads list API - self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) - # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. 
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d1aa1947a5..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -82,11 +82,7 @@ class RelationPaginationServlet(RestServlet): class ThreadsServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" - ), - ) + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P[^/]*)/threads"),) def __init__(self, hs: "HomeServer"): super().__init__() @@ -126,5 +122,4 @@ class ThreadsServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) - if hs.config.experimental.msc3856_enabled: - ThreadsServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d595295e2c..f5c1070b2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1710,7 +1710,15 @@ class RelationRedactionTestCase(BaseRelationsTestCase): class ThreadsTestCase(BaseRelationsTestCase): - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def _get_threads(self, body: JsonDict) -> List[Tuple[str, str]]: + return [ + ( + ev["event_id"], + ev["unsigned"]["m.relations"]["m.thread"]["latest_event"]["event_id"], + ) + for ev in body["chunk"] + ] + def test_threads(self) -> None: """Create threads and ensure the ordering is due to their latest event.""" # Create 2 threads. @@ -1718,32 +1726,37 @@ class ThreadsTestCase(BaseRelationsTestCase): res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) thread_2 = res["event_id"] - self._send_relation(RelationTypes.THREAD, "m.room.test") - self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_1 = channel.json_body["event_id"] + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", parent_id=thread_2 + ) + reply_2 = channel.json_body["event_id"] # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_2, thread_1]) + threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) # Update the first thread, the ordering should swap. - self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_3 = channel.json_body["event_id"] channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] - self.assertEqual(thread_roots, [thread_1, thread_2]) + # Tuple of (thread ID, latest event ID) for each thread. 
+ threads = self._get_threads(channel.json_body) + self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_pagination(self) -> None: """Create threads and paginate through them.""" # Create 2 threads. @@ -1757,7 +1770,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Request the threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1771,7 +1784,7 @@ class ThreadsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1780,7 +1793,6 @@ class ThreadsTestCase(BaseRelationsTestCase): self.assertNotIn("next_batch", channel.json_body, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_include(self) -> None: """Filtering threads to all or participated in should work.""" # Thread 1 has the user as the root event. @@ -1807,7 +1819,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # All threads in the room. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -1819,14 +1831,13 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only participated threads. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) - @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) def test_ignored_user(self) -> None: """Events from ignored users should be ignored.""" # Thread 1 has a reply from an ignored user. @@ -1852,7 +1863,7 @@ class ThreadsTestCase(BaseRelationsTestCase): # Only thread 1 is returned. channel = self.make_request( "GET", - f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) -- cgit 1.4.1 From 126a15794c95002560709283640ad412636b29b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 08:30:05 -0400 Subject: Do not allow a None-limit on PaginationConfig. (#14146) The callers either set a default limit or manually handle a None-limit later on (by setting a default value). Update the callers to always instantiate PaginationConfig with a default limit and then assume the limit is non-None. 
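In sketch form, the calling convention this commit standardises on (illustrative only: `store` and `request` stand in for the DataStore and SynapseRequest that the surrounding servlet already has; the default of 10 matches the initial-sync callers below):

    from synapse.streams.config import PaginationConfig

    async def get_pagination_config(store, request) -> PaginationConfig:
        # default_limit is now a required argument, so pagin_config.limit is
        # always an int and the per-caller "if limit is None: limit = 10"
        # fallbacks removed below become unnecessary.
        return await PaginationConfig.from_request(
            store, request, default_limit=10
        )
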
--- changelog.d/14146.removal | 1 + synapse/handlers/account_data.py | 2 +- synapse/handlers/initial_sync.py | 27 ++++----------------------- synapse/handlers/pagination.py | 5 ----- synapse/handlers/presence.py | 4 +++- synapse/handlers/receipts.py | 2 +- synapse/handlers/relations.py | 3 --- synapse/handlers/room.py | 2 +- synapse/handlers/typing.py | 2 +- synapse/rest/client/events.py | 4 +++- synapse/rest/client/initial_sync.py | 4 +++- synapse/rest/client/room.py | 4 +++- synapse/storage/databases/main/stream.py | 2 -- synapse/streams/__init__.py | 2 +- synapse/streams/config.py | 12 +++++------- tests/rest/client/test_typing.py | 3 ++- 16 files changed, 29 insertions(+), 50 deletions(-) create mode 100644 changelog.d/14146.removal (limited to 'tests/rest') diff --git a/changelog.d/14146.removal b/changelog.d/14146.removal new file mode 100644 index 0000000000..08fa752897 --- /dev/null +++ b/changelog.d/14146.removal @@ -0,0 +1 @@ +Remove the unstable identifier for [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 0478448b47..fc21d58001 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 860c82c110..9c335e6863 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -57,13 +57,7 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, - Optional[StreamToken], - Optional[StreamToken], - str, - Optional[int], - bool, - bool, + str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() @@ -154,11 +148,6 @@ class InitialSyncHandler: public_room_ids = await self.store.get_public_room_ids() - if pagin_config.limit is not None: - limit = pagin_config.limit - else: - limit = 10 - serializer_options = SerializeEventConfig(as_client_event=as_client_event) async def handle_room(event: RoomsForUser) -> None: @@ -210,7 +199,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, event.room_id, - limit=limit, + limit=pagin_config.limit, end_token=room_end_token, ), deferred_room_state, @@ -360,15 +349,11 @@ class InitialSyncHandler: member_event_id ) - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - leave_position = await self.store.get_position_for_event(member_event_id) stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( - room_id, limit=limit, end_token=stream_token + room_id, limit=pagin_config.limit, end_token=stream_token ) messages = await filter_events_for_client( @@ -420,10 +405,6 @@ class InitialSyncHandler: now_token = self.hs.get_event_sources().get_current_token() - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - room_members = [ m for m in current_state.values() @@ -467,7 +448,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, room_id, - limit=limit, + limit=pagin_config.limit, end_token=now_token.room_key, ), ), 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1f83bab836..a4ca9cb8b4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -458,11 +458,6 @@ class PaginationHandler: # `/messages` should still works with live tokens when manually provided. assert from_token.room_key.topological is not None - if pagin_config.limit is None: - # This shouldn't happen as we've set a default limit before this - # gets called. - raise Exception("limit not set") - room_token = from_token.room_key async with self.pagination_lock.read(room_id): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4e575ffbaa..2670e561d7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self, user: UserID, from_key: Optional[int], - limit: Optional[int] = None, + # Having a default limit doesn't match the EventSource API, but some + # callers do not provide it. It is unused in this class. + limit: int = 0, room_ids: Optional[Collection[str]] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4a7ec9e426..ac01582442 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -257,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 1fdd7a10bc..0a0c6d938e 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -116,9 +116,6 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - # TODO Update pagination config to not allow None limits. - assert pagin_config.limit is not None - # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). 
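The EventSource interface is tightened to match. Its shared signature (the synapse/streams/__init__.py hunk below is the authoritative change) becomes, in sketch form:

    from typing import Collection, Generic, List, Optional, Tuple, TypeVar

    from synapse.types import UserID

    K = TypeVar("K")
    R = TypeVar("R")

    class EventSource(Generic[K, R]):
        async def get_new_events(
            self,
            user: UserID,
            from_key: K,
            limit: int,  # previously Optional[int]
            room_ids: Collection[str],
            is_guest: bool,
            explicit_room_id: Optional[str] = None,
        ) -> Tuple[List[R], K]:
            raise NotImplementedError

PresenceEventSource is the one deliberate exception: it keeps a default (of 0) because some callers omit the argument and the value is unused in that class.
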
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 57ab05ad25..4e1aacb408 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1646,7 +1646,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): self, user: UserID, from_key: RoomStreamToken, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index f953691669..a0ea719430 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 916f5230f1..782e7d14e8 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet): raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") - pagin_config = await PaginationConfig.from_request(self.store, request) + pagin_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in args: try: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index cfadcb8e50..9b1bb8b521 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) args: Dict[bytes, List[bytes]] = request.args # type: ignore as_client_event = b"raw" not in args - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index b6dedbed04..01e5079963 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ffeb2b3683..5baffbfe55 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1200,8 +1200,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. 
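Dropping `assert int(limit) >= 0` from the stream store is safe because the bound is now enforced once, at parse time; see the synapse/streams/config.py hunk below. A self-contained sketch of that validation (MAX_LIMIT stands in for the constant defined in that module):

    from synapse.api.errors import SynapseError
    from synapse.http.servlet import parse_integer

    MAX_LIMIT = 1000  # placeholder for the constant in synapse/streams/config.py

    def parse_limit(request, default_limit: int) -> int:
        # Negative limits are rejected up front and the result is clamped,
        # so downstream storage code can assume 0 <= limit <= MAX_LIMIT.
        limit = parse_integer(request, "limit", default=default_limit)
        if limit < 0:
            raise SynapseError(400, "Limit must be 0 or above")
        return min(limit, MAX_LIMIT)
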
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 806b671305..2dcd43d0a2 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -27,7 +27,7 @@ class EventSource(Generic[K, R]): self, user: UserID, from_key: K, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index f6f7bf3d8b..6df2de919c 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -35,14 +35,14 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] direction: str - limit: Optional[int] + limit: int @classmethod async def from_request( cls, store: "DataStore", request: SynapseRequest, - default_limit: Optional[int] = None, + default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": direction = parse_string( @@ -69,12 +69,10 @@ class PaginationConfig: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") - if limit: - if limit < 0: - raise SynapseError(400, "Limit must be 0 or above") - - limit = min(int(limit), MAX_LIMIT) + limit = min(limit, MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 61b66d7685..fdc433a8b5 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -59,7 +59,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=UserID.from_string(self.user_id), from_key=0, - limit=None, + # Limit is unused. + limit=0, room_ids=[self.room_id], is_guest=False, ) -- cgit 1.4.1 From 4283bd1cf9c3da2157c3642a7c4f105e9fac2636 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Oct 2022 11:32:11 -0400 Subject: Support filtering the /messages API by relation type (MSC3874). (#14148) Gated behind an experimental configuration flag. --- changelog.d/14148.feature | 1 + synapse/api/filtering.py | 27 +++++- synapse/config/experimental.py | 3 + synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/stream.py | 29 ++++++- tests/api/test_filtering.py | 63 +++++++++++++- tests/rest/client/test_relations.py | 1 - tests/rest/client/test_rooms.py | 145 ++----------------------------- tests/storage/test_stream.py | 118 ++++++++++++++++++------- 9 files changed, 212 insertions(+), 177 deletions(-) create mode 100644 changelog.d/14148.feature (limited to 'tests/rest') diff --git a/changelog.d/14148.feature b/changelog.d/14148.feature new file mode 100644 index 0000000000..951d0cac80 --- /dev/null +++ b/changelog.d/14148.feature @@ -0,0 +1 @@ +Experimental support for [MSC3874](https://github.com/matrix-org/matrix-spec-proposals/pull/3874). 
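From a client's perspective, the new filter fields let /messages return a room's main timeline without threaded replies. A sketch of such a request (the room ID is a placeholder; servers advertise support via "org.matrix.msc3874" in /versions, per the hunk further down):

    import json
    from urllib.parse import quote

    room_id = "!room:example.org"  # placeholder
    # Hide events that are part of a thread using the unstable filter field.
    message_filter = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
    path = (
        f"/_matrix/client/v3/rooms/{quote(room_id)}/messages"
        f"?dir=b&filter={quote(json.dumps(message_filter))}"
    )
    # GET `path` with an ordinary access token.
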
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..f9a49451d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -117,3 +117,6 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4e1fd2bbe7..4b87ee978a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -114,6 +114,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. 
+ "org.matrix.msc3874": self.config.experimental.msc3874_enabled, }, }, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 5baffbfe55..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1278,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1294,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a269c477fb..a82c4eed86 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -35,6 +35,8 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) @@ -357,6 +359,66 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_rel_type(self): + definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_not_rel_type(self): + definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + 
type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = self.get_success( @@ -456,7 +518,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): events = [ # An event without a relation. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f5c1070b2c..ddf315b894 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1677,7 +1677,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ Test that thread replies are still available when the root event is redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3612ebe7b9..71b1637be8 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -35,7 +35,6 @@ from synapse.api.constants import ( EventTypes, Membership, PublicRoomsFilterFields, - RelationTypes, RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException @@ -50,6 +49,7 @@ from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -2915,149 +2915,20 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - self.helper.join( - room=self.room_id, user=self.second_user_id, tok=self.second_tok - ) - - self.third_user_id = self.register_user("third", "test") - self.third_tok = self.login("third", "test") - self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) - - # An initial event with a relation from second user. 
- res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 1"}, - tok=self.tok, - ) - self.event_id_1 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.ANNOTATION, - "event_id": self.event_id_1, - "key": "👍", - } - }, - tok=self.second_tok, - ) - - # Another event with a relation from third user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 2"}, - tok=self.tok, - ) - self.event_id_2 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.REFERENCE, - "event_id": self.event_id_2, - } - }, - tok=self.third_tok, - ) - - # An event with no relations. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "No relations"}, - tok=self.tok, - ) - - def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: +class RelationsTestCase(PaginationTestCase): + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" + from_token = self.get_success( + self.from_token.to_string(self.hs.get_datastores().main) + ) channel = self.make_request( "GET", - "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}", access_token=self.tok, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - return channel.json_body["chunk"] - - def test_filter_relation_senders(self) -> None: - # Messages which second user reacted to. - filter = {"related_by_senders": [self.second_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which third user reacted to. - filter = {"related_by_senders": [self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which either user reacted to. - filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_type(self) -> None: - # Messages which have annotations. - filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which have references. - filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which have either annotations or references. - filter = { - "related_by_rel_types": [ - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - ] - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_senders_and_type(self) -> None: - # Messages which second user reacted to. 
- filter = { - "related_by_senders": [self.second_user_id], - "related_by_rel_types": [RelationTypes.ANNOTATION], - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) + return [ev["event_id"] for ev in channel.json_body["chunk"]] class ContextTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 78663a53fe..34fa810cf6 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -16,7 +16,6 @@ from typing import List from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter -from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.types import JsonDict @@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} + config["experimental_features"] = {"msc3874_enabled": True} return config def prepare(self, reactor, clock, homeserver): @@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase): self.third_tok = self.login("third", "test") self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + # Store a token which is after all the room creation events. + self.from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) + ) + # An initial event with a relation from second user. res = self.helper.send_event( room_id=self.room_id, @@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_1 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.second_tok, ) + self.event_id_annotation = res["event_id"] # Another event with a relation from third user. res = self.helper.send_event( @@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_2 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.third_tok, ) + self.event_id_reference = res["event_id"] # An event with no relations. - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "No relations"}, tok=self.tok, ) + self.event_id_none = res["event_id"] - def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) - ) - events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, - from_key=from_token.room_key, + from_key=self.from_token.room_key, to_key=None, - direction="b", + direction="f", limit=10, event_filter=Filter(self.hs, filter), ) ) - return events + return [ev.event_id for ev in events] def test_filter_relation_senders(self): # Messages which second user reacted to. 
filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which third user reacted to. filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which either user reacted to. filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_type(self): # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which have references. filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which have either annotations or references. filter = { @@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase): ] } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. @@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase): "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) def test_duplicate_relation(self): """An event should only be returned once if there are multiple relations to it.""" @@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase): filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) + + def test_filter_rel_types(self) -> None: + # Messages which are annotations. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_annotation]) + + # Messages which are references. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_reference]) + + # Messages which are either annotations or references. + filter = { + "org.matrix.msc3874.rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertCountEqual( + chunk, + [self.event_id_annotation, self.event_id_reference], + ) + + def test_filter_not_rel_types(self) -> None: + # Messages which are not annotations. 
+ filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_2, + self.event_id_reference, + self.event_id_none, + ], + ) + + # Messages which are not references. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_annotation, + self.event_id_2, + self.event_id_none, + ], + ) + + # Messages which are neither annotations or references. + filter = { + "org.matrix.msc3874.not_rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none]) -- cgit 1.4.1 From 4eaf3eb840b8cfa78d970216c74fc128495f08a5 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 18 Oct 2022 16:52:25 +0100 Subject: Implementation of HTTP 307 response for MSC3886 POST endpoint (#14018) Co-authored-by: reivilibre Co-authored-by: Andrew Morgan --- changelog.d/14018.feature | 1 + synapse/config/experimental.py | 7 +- synapse/config/server.py | 4 ++ synapse/handlers/sso.py | 2 +- synapse/http/server.py | 48 ++++++++++--- synapse/http/site.py | 3 + synapse/rest/__init__.py | 2 + synapse/rest/client/rendezvous.py | 74 +++++++++++++++++++ synapse/rest/client/versions.py | 3 + synapse/rest/key/v2/local_key_resource.py | 4 +- synapse/rest/synapse/client/new_user_consent.py | 3 +- synapse/rest/well_known.py | 3 +- tests/logging/test_terse_json.py | 1 + tests/rest/client/test_rendezvous.py | 45 ++++++++++++ tests/server.py | 8 ++- tests/test_server.py | 94 ++++++++++++++++++------- 16 files changed, 257 insertions(+), 45 deletions(-) create mode 100644 changelog.d/14018.feature create mode 100644 synapse/rest/client/rendezvous.py create mode 100644 tests/rest/client/test_rendezvous.py (limited to 'tests/rest') diff --git a/changelog.d/14018.feature b/changelog.d/14018.feature new file mode 100644 index 0000000000..c8454607eb --- /dev/null +++ b/changelog.d/14018.feature @@ -0,0 +1 @@ +Support for redirecting to an implementation of a [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) HTTP rendezvous service. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f9a49451d8..4009add01d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import attr @@ -120,3 +120,8 @@ class ExperimentalConfig(Config): # MSC3874: Filtering /messages with rel_types / not_rel_types. 
self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) + + # MSC3886: Simple client rendezvous capability + self.msc3886_endpoint: Optional[str] = experimental.get( + "msc3886_endpoint", None + ) diff --git a/synapse/config/server.py b/synapse/config/server.py index f2353ce5fb..ec46ca63ad 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -207,6 +207,9 @@ class HttpListenerConfig: additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None request_id_header: Optional[str] = None + # If true, the listener will return CORS response headers compatible with MSC3886: + # https://github.com/matrix-org/matrix-spec-proposals/pull/3886 + experimental_cors_msc3886: bool = False @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -935,6 +938,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), request_id_header=listener.get("request_id_header"), + experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e035677b8a..5943f08e91 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -874,7 +874,7 @@ class SsoHandler: ) async def handle_terms_accepted( - self, request: Request, session_id: str, terms_version: str + self, request: SynapseRequest, session_id: str, terms_version: str ) -> None: """Handle a request to the new-user 'consent' endpoint diff --git a/synapse/http/server.py b/synapse/http/server.py index bcbfac2c9f..b26e34bceb 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -19,6 +19,7 @@ import logging import types import urllib from http import HTTPStatus +from http.client import FOUND from inspect import isawaitable from typing import ( TYPE_CHECKING, @@ -339,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - _unrecognised_request_handler(request) + return _unrecognised_request_handler(request) @abc.abstractmethod def _send_response( @@ -598,7 +599,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: Request) -> bytes: + def render_OPTIONS(self, request: SynapseRequest) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -763,7 +764,7 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, + request: SynapseRequest, code: int, json_bytes: bytes, send_cors: bool = False, @@ -859,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request) -> None: +def set_cors_headers(request: SynapseRequest) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -870,10 +871,20 @@ def set_cors_headers(request: Request) -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date", - ) + if request.experimental_cors_msc3886: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", + ) + request.setHeader( + 
b"Access-Control-Expose-Headers", + b"ETag, Location, X-Max-Bytes", + ) + else: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date", + ) def set_corp_headers(request: Request) -> None: @@ -942,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None: request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def respond_with_redirect(request: Request, url: bytes) -> None: - """Write a 302 response to the request, if it is still alive.""" +def respond_with_redirect( + request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False +) -> None: + """ + Write a 302 (or other specified status code) response to the request, if it is still alive. + + Args: + request: The http request to respond to. + url: The URL to redirect to. + statusCode: The HTTP status code to use for the redirect (defaults to 302). + cors: Whether to set CORS headers on the response. + """ logger.debug("Redirect to %s", url.decode("utf-8")) - request.redirect(url) + + if cors: + set_cors_headers(request) + + request.setResponseCode(statusCode) + request.setHeader(b"location", url) finish_request(request) diff --git a/synapse/http/site.py b/synapse/http/site.py index 55a6afce35..3dbd541fed 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -82,6 +82,7 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 + self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -622,6 +623,8 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header + self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9a2ab99ede..28542cd774 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client import ( receipts, register, relations, + rendezvous, report_event, room, room_batch, @@ -132,3 +133,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) login_token_request.register_servlets(hs, client_resource) + rendezvous.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py new file mode 100644 index 0000000000..89176b1ffa --- /dev/null +++ b/synapse/rest/client/rendezvous.py @@ -0,0 +1,74 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from http.client import TEMPORARY_REDIRECT +from typing import TYPE_CHECKING, Optional + +from synapse.http.server import HttpServer, respond_with_redirect +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RendezvousServlet(RestServlet): + """ + This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) + simple client rendezvous capability that is used by the "Sign in with QR" functionality. + + This implementation only serves as a 307 redirect to a configured server rather than being a full implementation. + + A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/. + + Request: + + POST /rendezvous HTTP/1.1 + Content-Type: ... + + ... + + Response: + + HTTP/1.1 307 + Location: + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint + assert ( + redirection_target is not None + ), "Servlet is only registered if there is a redirection target" + self.endpoint = redirection_target.encode("utf-8") + + async def on_POST(self, request: SynapseRequest) -> None: + respond_with_redirect( + request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True + ) + + # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3886_endpoint is not None: + RendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4b87ee978a..9b1b72c68a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -116,6 +116,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3881": self.config.experimental.msc3881_enabled, # Adds support for filtering /messages by event relation. "org.matrix.msc3874": self.config.experimental.msc3874_enabled, + # Adds support for simple HTTP rendezvous as per MSC3886 + "org.matrix.msc3886": self.config.experimental.msc3886_endpoint + is not None, }, }, ) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 0c9f042c84..095993415c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -20,9 +20,9 @@ from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 from twisted.web.resource import Resource -from twisted.web.server import Request from synapse.http.server import respond_with_json_bytes +from synapse.http.site import SynapseRequest from synapse.types import JsonDict if TYPE_CHECKING: @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> Optional[int]: + def render_GET(self, request: SynapseRequest) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. 
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 1c1c7b3613..22784157e6 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.types import UserID from synapse.util.templates import build_jinja_env @@ -88,7 +89,7 @@ class NewUserConsentResource(DirectServeHtmlResource): html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 6f7ac54c65..e2174fdfea 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,6 +18,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse.http.server import set_cors_headers +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.stringutils import parse_server_name @@ -63,7 +64,7 @@ class ClientWellKnownResource(Resource): Resource.__init__(self) self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request: Request) -> bytes: + def render_GET(self, request: SynapseRequest) -> bytes: set_cors_headers(request) r = self._well_known_builder.get_well_known() if not r: diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 96f399b7ab..0b0d8737c1 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -153,6 +153,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.site_tag = "test-site" site.server_version_string = "Server v1" site.reactor = Mock() + site.experimental_cors_msc3886 = False request = SynapseRequest(FakeChannel(site, None), site) # Call requestReceived to finish instantiating the object. request.content = BytesIO() diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py new file mode 100644 index 0000000000..ad00a476e1 --- /dev/null +++ b/tests/rest/client/test_rendezvous.py @@ -0,0 +1,45 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.client import rendezvous +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + +endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" + + +class RendezvousServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + rendezvous.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + return self.hs + + def test_disabled(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) + def test_redirect(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 307) + self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"]) diff --git a/tests/server.py b/tests/server.py index c447d5e4c4..8b1d186219 100644 --- a/tests/server.py +++ b/tests/server.py @@ -266,7 +266,12 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") - def __init__(self, resource: IResource, reactor: IReactorTime): + def __init__( + self, + resource: IResource, + reactor: IReactorTime, + experimental_cors_msc3886: bool = False, + ): """ Args: @@ -274,6 +279,7 @@ class FakeSite: """ self._resource = resource self.reactor = reactor + self.experimental_cors_msc3886 = experimental_cors_msc3886 def getResourceFor(self, request): return self._resource diff --git a/tests/test_server.py b/tests/test_server.py index 7c66448245..2d9a0257d4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -222,13 +222,22 @@ class OptionsResourceTests(unittest.TestCase): self.resource = OptionsResource() self.resource.putChild(b"res", DummyResource()) - def _make_request(self, method: bytes, path: bytes) -> FakeChannel: + def _make_request( + self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False + ) -> FakeChannel: """Create a request from the method/path and return a channel with the response.""" # Create a site and query for the resource. 
site = SynapseSite( "test", "site_tag", - parse_listener_def(0, {"type": "http", "port": 0}), + parse_listener_def( + 0, + { + "type": "http", + "port": 0, + "experimental_cors_msc3886": experimental_cors_msc3886, + }, + ), self.resource, "1.0", max_request_body_size=4096, @@ -239,25 +248,58 @@ class OptionsResourceTests(unittest.TestCase): channel = make_request(self.reactor, site, method, path, shorthand=False) return channel + def _check_cors_standard_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [b"X-Requested-With, Content-Type, Authorization, Date"], + "has correct CORS Headers header", + ) + + def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [ + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match" + ], + "has correct CORS Headers header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"), + [b"ETag, Location, X-Max-Bytes"], + "has correct CORS Expose Headers header", + ) + def test_unknown_options_request(self) -> None: """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", - ) + self._check_cors_standard_headers(channel) def test_known_options_request(self) -> None: """An OPTIONS requests to an known URL still returns 204 No Content.""" @@ -265,19 +307,17 @@ class OptionsResourceTests(unittest.TestCase): self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", + self._check_cors_standard_headers(channel) + + def 
test_known_options_request_msc3886(self) -> None:
+        """An OPTIONS request to a known URL still returns 204 No Content."""
+        channel = self._make_request(
+            b"OPTIONS", b"/res/", experimental_cors_msc3886=True
         )
+        self.assertEqual(channel.code, 204)
+        self.assertNotIn("body", channel.result)
+
+        self._check_cors_msc3886_headers(channel)
 
     def test_unknown_request(self) -> None:
         """A non-OPTIONS request to an unknown URL should 404."""
-- 
cgit 1.4.1


From fa8616e65c82367712a7b75c62682a89541b6330 Mon Sep 17 00:00:00 2001
From: Eric Eastwood
Date: Tue, 18 Oct 2022 19:46:25 -0500
Subject: Fix MSC3030 `/timestamp_to_event` returning `outliers` it cannot
 tell are near a gap (#14215)

Fix the MSC3030 `/timestamp_to_event` endpoint returning `outliers` when it
has no way of knowing whether they are near a gap (and therefore whether
they are actually the closest event).

Synapse can't tell whether an `outlier` is next to a gap because our gap
checks rely on entries in the `event_edges`, `event_forward_extremities`,
and `event_backward_extremities` tables, which are [not populated for
`outliers`](https://github.com/matrix-org/synapse/blob/2c63cdcc3f1aa4625e947de3c23e0a8133c61286/docs/development/room-dag-concepts.md#outliers).

Also fixes the MSC3030 Complement
`can_paginate_after_getting_remote_event_from_timestamp_to_event_endpoint`
test flake. Although the failure presented as flakiness in Complement, if
`sync_partial_state` raced and beat us to `/timestamp_to_event`, then even
retrying the failing `/context` request wouldn't work until we made this
Synapse change. With this PR, Synapse will never return an `outlier` event,
so that test will always go and ask over federation.

Fixes https://github.com/matrix-org/synapse/issues/13944

### Why did this fail before? Why was it flakey?

Sleuthing the server logs on the [CI failure](https://github.com/matrix-org/synapse/actions/runs/3149623842/jobs/5121449357#step:5:5805),
it looks like `hs2:/timestamp_to_event` found
`$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` event
locally. Then, when we went and asked for it via `/context`, it was filtered
out of the results because it's an `outlier` ->
`You don't have permission to access that event.`

This is reproducible when `sync_partial_state` races and persists
`$NP6-oU7mIFVyhtKfGvfrEQX949hQX-T-gvuauG6eurU` as an `outlier` before we
evaluate `get_event_for_timestamp(...)`. To consistently reproduce locally,
just add a delay at the [start of `get_event_for_timestamp(...)`](https://github.com/matrix-org/synapse/blob/cb20b885cb4bd1648581dd043a184d86fc8c7a00/synapse/handlers/room.py#L1470-L1496)
so it always runs after `sync_partial_state` completes.

```py
from twisted.internet import task as twisted_task
d = twisted_task.deferLater(self.hs.get_reactor(), 3.5)
await d
```

In a run where it passes, `get_event_for_timestamp(...)` on `hs2` finds a
different event locally which is next to a gap, and we request a closer one
from `hs1`, which gets backfilled. Since the backfilled event is not an
`outlier`, it's returned as expected during `/context`.
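In sketch form (illustrative only, not the exact Synapse code; table and
column names mirror the diff below), the fix bakes both the direction
handling and the `outlier` exclusion into the SQL template itself:

```py
# Simplified sketch of the query built by get_event_id_for_timestamp.
def build_timestamp_to_event_sql(direction: str) -> str:
    if direction == "b":
        # Closest event *before* the timestamp: largest origin_server_ts <= ts.
        comparison_operator, order = "<=", "DESC"
    else:
        # Closest event *after* the timestamp: smallest origin_server_ts >= ts.
        comparison_operator, order = ">=", "ASC"

    return f"""
        SELECT event_id FROM events
        LEFT JOIN rejections USING (event_id)
        WHERE
            room_id = ?
            AND origin_server_ts {comparison_operator} ?
            -- outliers have no entries in the gap-tracking tables, so we
            -- can't tell whether they're next to a gap: never return them.
            AND NOT outlier
            AND rejections.event_id IS NULL
        ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order}
        LIMIT 1;
    """
```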
--- changelog.d/14215.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 59 ++++++++++++++-------- tests/rest/client/test_rooms.py | 65 +++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 21 deletions(-) create mode 100644 changelog.d/14215.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14215.bugfix b/changelog.d/14215.bugfix new file mode 100644 index 0000000000..31c109f534 --- /dev/null +++ b/changelog.d/14215.bugfix @@ -0,0 +1 @@ +Fix [MSC3030](https://github.com/matrix-org/matrix-spec-proposals/pull/3030) `/timestamp_to_event` endpoint returning potentially inaccurate closest events with `outliers` present. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7bc7f2f33e..69fea452ad 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1971,12 +1971,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_backward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_backward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question has any of its prev_events listed as a # backward extremity, it's next to a gap. @@ -2026,12 +2031,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_forward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question is a forward extremity, we will just # consider any potential forward gap as not a gap since it's one of @@ -2112,13 +2122,33 @@ class EventsWorkerStore(SQLBaseStore): The closest event_id otherwise None if we can't find any event in the given direction. """ + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" - sql_template = """ + sql_template = f""" SELECT event_id FROM events LEFT JOIN rejections USING (event_id) WHERE - origin_server_ts %s ? - AND room_id = ? + room_id = ? + AND origin_server_ts {comparison_operator} ? + /** + * Make sure the event isn't an `outlier` because we have no way + * to later check whether it's next to a gap. `outliers` do not + * have entries in the `event_edges`, `event_forward_extremeties`, + * and `event_backward_extremities` tables to check against + * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`). 
+ */ + AND NOT outlier /* Make sure event is not rejected */ AND rejections.event_id IS NULL /** @@ -2128,27 +2158,14 @@ class EventsWorkerStore(SQLBaseStore): * Finally, we can tie-break based on when it was received on the server * (`stream_ordering`). */ - ORDER BY origin_server_ts %s, depth %s, stream_ordering %s + ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order} LIMIT 1; """ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - txn.execute( - sql_template % (comparison_operator, order, order, order), - (timestamp, room_id), + sql_template, + (room_id, timestamp), ) row = txn.fetchone() if row: diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 71b1637be8..716366eb90 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -39,6 +39,8 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, register, room, sync @@ -51,6 +53,7 @@ from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import create_event PATH_PREFIX = b"/_matrix/client/api/v1" @@ -3486,3 +3489,65 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM") + + +class TimestampLookupTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc3030_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._storage_controllers = self.hs.get_storage_controllers() + + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + def _inject_outlier(self, room_id: str) -> EventBase: + event, _context = self.get_success( + create_event( + self.hs, + room_id=room_id, + type="m.test", + sender="@test_remote_user:remote", + ) + ) + + event.internal_metadata.outlier = True + self.get_success( + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) + ) + ) + return event + + def test_no_outliers(self) -> None: + """ + Test to make sure `/timestamp_to_event` does not return `outlier` events. + We're unable to determine whether an `outlier` is next to a gap so we + don't know whether it's actually the closest event. 
Instead, let's just
+        ignore `outliers` with this endpoint.
+
+        This test really checks that we choose the non-`outlier` event behind the
+        `outlier`. Since the gap checking logic considers the latest message in the
+        room as *not* next to a gap, asking over federation does not come into play
+        here.
+        """
+        room_id = self.helper.create_room_as(self.room_owner, tok=self.room_owner_tok)
+
+        outlier_event = self._inject_outlier(room_id)
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
+            access_token=self.room_owner_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+        # Make sure the outlier event is not returned
+        self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id)
-- 
cgit 1.4.1


From 755bfeee3a1ac7077045ab9e5a994b6ca89afba3 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 20 Oct 2022 11:32:47 -0400
Subject: Use servlets for /key/ endpoints. (#14229)

This fixes the response for unknown endpoints under that prefix. See MSC3743.
---
 changelog.d/14229.misc                        |  1 +
 synapse/api/urls.py                           |  2 +-
 synapse/app/generic_worker.py                 | 20 +++-----
 synapse/app/homeserver.py                     | 26 ++++------
 synapse/rest/key/v2/__init__.py               | 19 ++++---
 synapse/rest/key/v2/local_key_resource.py     | 22 ++++----
 synapse/rest/key/v2/remote_key_resource.py    | 73 +++++++++++++++------------
 tests/app/test_openid_listener.py             |  2 +-
 tests/rest/key/v2/test_remote_key_resource.py |  4 +-
 9 files changed, 86 insertions(+), 83 deletions(-)
 create mode 100644 changelog.d/14229.misc

(limited to 'tests/rest')

diff --git a/changelog.d/14229.misc b/changelog.d/14229.misc
new file mode 100644
index 0000000000..b9cd9a34d5
--- /dev/null
+++ b/changelog.d/14229.misc
@@ -0,0 +1 @@
+Refactor `/key/` endpoints to use `RestServlet` classes.
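To make the refactor easier to follow, here is a rough before/after sketch of
the registration change (condensed from the diffs below; the real patch also
moves URL parsing into the servlets themselves):

```py
# Before: a plain Twisted Resource tree. Requests for any other path under
# /_matrix/key/v2 fell through to Twisted's default, non-JSON 404 handling.
class KeyApiV2Resource(Resource):
    def __init__(self, hs: "HomeServer"):
        Resource.__init__(self)
        self.putChild(b"server", LocalKey(hs))
        self.putChild(b"query", RemoteKey(hs))

# After: a JsonResource that registers RestServlets by URL pattern, so an
# unknown path under the prefix gets Synapse's standard JSON "unrecognised
# request" error response instead, which is what MSC3743 asks for.
class KeyResource(JsonResource):
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs, canonical_json=True)
        LocalKey(hs).register(self)
        RemoteKey(hs).register(self)
```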
diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bd49fa6a5f..a918579f50 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -28,7 +28,7 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" -SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" +SERVER_KEY_PREFIX = "/_matrix/key" MEDIA_R0_PREFIX = "/_matrix/media/r0" MEDIA_V3_PREFIX = "/_matrix/media/v3" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index dc49840f73..2a9f039367 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -28,7 +28,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, ) from synapse.app import _base from synapse.app._base import ( @@ -89,7 +89,7 @@ from synapse.rest.client.register import ( RegistrationTokenValidityRestServlet, ) from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -325,13 +325,13 @@ class GenericWorkerServer(HomeServer): presence.register_servlets(self, resource) - resources.update({CLIENT_API_PREFIX: resource}) + resources[CLIENT_API_PREFIX] = resource resources.update(build_synapse_client_resource_tree(self)) - resources.update({"/.well-known": well_known_resource(self)}) + resources["/.well-known"] = well_known_resource(self) elif name == "federation": - resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) + resources[FEDERATION_PREFIX] = TransportLayerServer(self) elif name == "media": if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() @@ -359,16 +359,12 @@ class GenericWorkerServer(HomeServer): # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. 
- resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "replication": resources[REPLICATION_PREFIX] = ReplicationRestResource(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883f2fd2ec..de3f08876f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -31,7 +31,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, STATIC_PREFIX, ) from synapse.app import _base @@ -60,7 +60,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -215,30 +215,22 @@ class SynapseHomeServer(HomeServer): consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({"/_matrix/consent": consent_resource}) + resources["/_matrix/consent"] = consent_resource if name == "federation": federation_resource: Resource = TransportLayerServer(self) if compress: federation_resource = gz_wrap(federation_resource) - resources.update({FEDERATION_PREFIX: federation_resource}) + resources[FEDERATION_PREFIX] = federation_resource if name == "openid": - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["static", "client"]: - resources.update( - { - STATIC_PREFIX: StaticResource( - os.path.join(os.path.dirname(synapse.__file__), "static") - ) - } + resources[STATIC_PREFIX] = StaticResource( + os.path.join(os.path.dirname(synapse.__file__), "static") ) if name in ["media", "federation", "client"]: @@ -257,7 +249,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "metrics" and self.config.metrics.enable_metrics: metrics_resource: Resource = MetricsResource(RegistryProxy) diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 7f8c1de1ff..26403facb8 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -14,17 +14,20 @@ from typing import TYPE_CHECKING -from twisted.web.resource import Resource - -from .local_key_resource import LocalKey -from .remote_key_resource import RemoteKey +from synapse.http.server import HttpServer, JsonResource +from synapse.rest.key.v2.local_key_resource import LocalKey +from synapse.rest.key.v2.remote_key_resource import RemoteKey if TYPE_CHECKING: from synapse.server import HomeServer -class KeyApiV2Resource(Resource): +class KeyResource(JsonResource): def __init__(self, hs: "HomeServer"): - Resource.__init__(self) - self.putChild(b"server", LocalKey(hs)) - self.putChild(b"query", RemoteKey(hs)) + super().__init__(hs, 
canonical_json=True) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None: + LocalKey(hs).register(http_server) + RemoteKey(hs).register(http_server) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 095993415c..d03e728d42 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,16 +13,15 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +import re +from typing import TYPE_CHECKING, Optional, Tuple -from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from twisted.web.resource import Resource +from twisted.web.server import Request -from synapse.http.server import respond_with_json_bytes -from synapse.http.site import SynapseRequest +from synapse.http.servlet import RestServlet from synapse.types import JsonDict if TYPE_CHECKING: @@ -31,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LocalKey(Resource): +class LocalKey(RestServlet): """HTTP resource containing encoding the TLS X.509 certificate and NACL signature verification keys for this server:: @@ -61,18 +60,17 @@ class LocalKey(Resource): } """ - isLeaf = True + PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P[^/]*))?$"),) def __init__(self, hs: "HomeServer"): self.config = hs.config self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) - Resource.__init__(self) def update_response_body(self, time_now_msec: int) -> None: refresh_interval = self.config.key.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) - self.response_body = encode_canonical_json(self.response_json_object()) + self.response_body = self.response_json_object() def response_json_object(self) -> JsonDict: verify_keys = {} @@ -99,9 +97,11 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: SynapseRequest) -> Optional[int]: + def on_GET( + self, request: Request, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes(request, 200, self.response_body) + return 200, self.response_body diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7f8ad29566..19820886f5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,15 +13,20 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Dict, Set +import re +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from signedjson.sign import sign_json -from synapse.api.errors import Codes, SynapseError +from twisted.web.server import Request + from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.http.site import SynapseRequest +from synapse.http.server import HttpServer +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, +) from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -32,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RemoteKey(DirectServeJsonResource): +class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource): } """ - isLeaf = True - def __init__(self, hs: "HomeServer"): - super().__init__() - self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: SynapseRequest) -> None: - assert request.postpath is not None - if len(request.postpath) == 1: - (server,) = request.postpath - query: dict = {server.decode("ascii"): {}} - elif len(request.postpath) == 2: - server, key_id = request.postpath + def register(self, http_server: HttpServer) -> None: + http_server.register_paths( + "GET", + ( + re.compile( + "^/_matrix/key/v2/query/(?P[^/]*)(/(?P[^/]*))?$" + ), + ), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "POST", + (re.compile("^/_matrix/key/v2/query$"),), + self.on_POST, + self.__class__.__name__, + ) + + async def on_GET( + self, request: Request, server: str, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: + if server and key_id: minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} + query = {server: {key_id: arguments}} else: - raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) + query = {server: {}} - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: SynapseRequest) -> None: + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) query = content["server_keys"] - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, - request: SynapseRequest, - query: JsonDict, - query_remote_on_cache_miss: bool = False, - ) -> None: + self, query: JsonDict, query_remote_on_cache_miss: bool = False + ) -> JsonDict: logger.info("Handling query for keys %r", query) store_queries = [] @@ -232,7 +245,7 @@ class 
RemoteKey(DirectServeJsonResource): for server_name, keys in cache_misses.items() ), ) - await self.query_keys(request, query, query_remote_on_cache_miss=False) + return await self.query_keys(query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json_raw in json_results: @@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource): signed_keys.append(key_json) - response = {"server_keys": signed_keys} - - respond_with_json(request, 200, response, canonical_json=True) + return {"server_keys": signed_keys} diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index c7dae58eb5..8d03da7f96 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): self.assertEqual(channel.code, 401) -@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock()) +@patch("synapse.app.homeserver.KeyResource", new=Mock()) class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index ac0ac06b7e..7f1fba1086 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict @@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def create_test_resource(self) -> Resource: return create_resource_tree( - {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource() ) def expect_outgoing_key_request( -- cgit 1.4.1 From 1433b5d5b64c3a6624e6e4ff4fef22127c49df86 Mon Sep 17 00:00:00 2001 From: Tadeusz Sośnierz Date: Fri, 21 Oct 2022 14:52:44 +0200 Subject: Show erasure status when listing users in the Admin API (#14205) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Show erasure status when listing users in the Admin API * Use USING when joining erased_users * Add changelog entry * Revert "Use USING when joining erased_users" This reverts commit 30bd2bf106415caadcfdbdd1b234ef2b106cc394. * Make the erased check work on postgres * Add a testcase for showing erased user status * Appease the style linter * Explicitly convert `erased` to bool to make SQLite consistent with Postgres This also adds us an easy way in to fix the other accidentally integered columns. 
* Move erasure status test to UsersListTestCase * Include user erased status when fetching user info via the admin API * Document the erase status in user_admin_api * Appease the linter and mypy * Signpost comments in tests Co-authored-by: Tadeusz Sośnierz Co-authored-by: David Robertson --- changelog.d/14205.feature | 1 + docs/admin_api/user_admin_api.md | 4 ++++ synapse/handlers/admin.py | 1 + synapse/storage/databases/main/__init__.py | 13 +++++++++-- tests/rest/admin/test_user.py | 35 +++++++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14205.feature (limited to 'tests/rest') diff --git a/changelog.d/14205.feature b/changelog.d/14205.feature new file mode 100644 index 0000000000..6692063352 --- /dev/null +++ b/changelog.d/14205.feature @@ -0,0 +1 @@ +Show erasure status when listing users in the Admin API. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 3625c7b6c5..c95d6c9b05 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -37,6 +37,7 @@ It returns a JSON body like the following: "is_guest": 0, "admin": 0, "deactivated": 0, + "erased": false, "shadow_banned": 0, "creation_ts": 1560432506, "appservice_id": null, @@ -167,6 +168,7 @@ A response body like the following is returned: "admin": 0, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": null, @@ -177,6 +179,7 @@ A response body like the following is returned: "admin": 1, "user_type": null, "deactivated": 0, + "erased": false, "shadow_banned": 0, "displayname": "", "avatar_url": "", @@ -247,6 +250,7 @@ The following fields are returned in the JSON response body: - `user_type` - string - Type of the user. Normal users are type `None`. This allows user type specific behaviour. There are also types `support` and `bot`. - `deactivated` - bool - Status if that user has been marked as deactivated. + - `erased` - bool - Status if that user has been marked as erased. - `shadow_banned` - bool - Status if that user has been marked as shadow banned. - `displayname` - string - The user's display name if they have set one. - `avatar_url` - string - The user's avatar URL if they have set one. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f2989cc4a2..5bf8e86387 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -100,6 +100,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids + user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) return user_info_dict diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index a62b4abd4e..cfaedf5e0c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -201,7 +201,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: str = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.NAME.value, direction: str = "f", approved: bool = True, ) -> Tuple[List[JsonDict], int]: @@ -261,6 +261,7 @@ class DataStore( sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? 
+ LEFT JOIN erased_users AS eu ON u.name = eu.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -269,7 +270,8 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts, approved + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, + eu.user_id is not null as erased {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? @@ -277,6 +279,13 @@ class DataStore( args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) + + # some of those boolean values are returned as integers when we're on SQLite + columns_to_boolify = ["erased"] + for user in users: + for column in columns_to_boolify: + user[column] = bool(user[column]) + return users, count return await self.db_pool.runInteraction( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4c1ce33463..63410ffdf1 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersions from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -924,6 +924,36 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def test_erasure_status(self) -> None: + # Create a new user. + user_id = self.register_user("eraseme", "eraseme") + + # They should appear in the list users API, marked as not erased. + channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], False) + + # Deactivate that user, requesting erasure. + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + user_id, erase_data=True, requester=create_requester(user_id) + ) + ) + + # Repeat the list users query. They should now be marked as erased. 
+ channel = self.make_request( + "GET", + self.url + "?deactivated=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIs(users[user_id]["erased"], True) + def _order_test( self, expected_user_list: List[str], @@ -1195,6 +1225,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) + self.assertFalse(channel.json_body["erased"]) # Deactivate and erase user channel = self.make_request( @@ -1219,6 +1250,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) + self.assertTrue(channel.json_body["erased"]) self._is_erased("@user:test", True) @@ -2757,6 +2789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("avatar_url", content) self.assertIn("admin", content) self.assertIn("deactivated", content) + self.assertIn("erased", content) self.assertIn("shadow_banned", content) self.assertIn("creation_ts", content) self.assertIn("appservice_id", content) -- cgit 1.4.1 From 4dd7aa371b6bc746fa4b0a9af220b2013b17a45d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Oct 2022 09:11:19 -0400 Subject: Properly update the threads table when thread events are redacted. (#14248) When the last event in a thread is redacted we need to update the threads table: * Find the new latest event in the thread and store it into the table; or * Remove the thread from the table if it is no longer a thread (i.e. all events in the thread were redacted). --- changelog.d/14248.bugfix | 1 + synapse/storage/databases/main/events.py | 61 ++++++++++++++--- tests/rest/client/test_relations.py | 110 +++++++++++++++++++++---------- 3 files changed, 129 insertions(+), 43 deletions(-) create mode 100644 changelog.d/14248.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14248.bugfix b/changelog.d/14248.bugfix new file mode 100644 index 0000000000..203c52c16b --- /dev/null +++ b/changelog.d/14248.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0rc1 where the information returned from the `/threads` API could be stale when threaded events are redacted. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6698cbf664..00880bb37d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2028,25 +2028,37 @@ class PersistEventsStore: redacted_event_id: The event that was redacted. """ - # Fetch the current relation of the event being redacted. - redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + # Fetch the relation of the event being redacted. + row = self.db_pool.simple_select_one_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id}, - retcol="relates_to_id", + retcols=("relates_to_id", "relation_type"), allow_none=True, ) + # Nothing to do if no relation is found. + if row is None: + return + + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + # Any relation information for the related event must be cleared. 
- if redacted_relates_to is not None: - self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + if rel_type == RelationTypes.ANNOTATION: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) ) + if rel_type == RelationTypes.THREAD: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_summary, (redacted_relates_to,) ) @@ -2057,9 +2069,38 @@ class PersistEventsStore: txn, self.store.get_threads, (room_id,) ) - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). + latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ddf315b894..e3d801f7a8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1523,6 +1523,26 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) + def _get_threads(self) -> List[Tuple[str, str]]: + """Request the threads in the room and returns a list of thread ID and latest event ID.""" + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + threads = channel.json_body["chunk"] + return [ + ( + t["event_id"], + t["unsigned"]["m.relations"][RelationTypes.THREAD]["latest_event"][ + "event_id" + ], + ) + for t in threads + ] + def test_redact_relation_annotation(self) -> None: """ Test that annotations of an event are properly handled after the @@ -1567,58 +1587,82 @@ class RelationRedactionTestCase(BaseRelationsTestCase): The redacted event should not be included in bundled aggregations or the response to relations. """ - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 1", "msgtype": "m.text"}, - ) - unredacted_event_id = channel.json_body["event_id"] + # Create a thread with a few events in it. 
+ thread_replies = [] + for i in range(3): + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": f"reply {i}", "msgtype": "m.text"}, + ) + thread_replies.append(channel.json_body["event_id"]) - # Note that the *last* event in the thread is redacted, as that gets - # included in the bundled aggregation. - channel = self._send_relation( - RelationTypes.THREAD, - EventTypes.Message, - content={"body": "reply 2", "msgtype": "m.text"}, + ################################################## + # Check the test data is configured as expected. # + ################################################## + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + relations = self._get_bundled_aggregations() + self.assertDictContainsSubset( + {"count": 3, "current_user_participated": True}, + relations[RelationTypes.THREAD], + ) + # The latest event is the last sent event. + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + thread_replies[-1], ) - to_redact_event_id = channel.json_body["event_id"] - # Both relations exist. - event_ids = self._get_related_events() + # There should be one thread, the latest event is the event that will be redacted. + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + ########################## + # Redact the last event. # + ########################## + self._redact(thread_replies.pop()) + + # The thread should still exist, but the latest event should be updated. + self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 2, - "current_user_participated": True, - }, + {"count": 2, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event returned is the event that will be redacted. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - to_redact_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) - # Redact one of the reactions. - self._redact(to_redact_event_id) + ########################################### + # Redact the *first* event in the thread. # + ########################################### + self._redact(thread_replies.pop(0)) - # The unredacted relation should still exist. - event_ids = self._get_related_events() + # Nothing should have changed (except the thread count). + self.assertEquals(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( - { - "count": 1, - "current_user_participated": True, - }, + {"count": 1, "current_user_participated": True}, relations[RelationTypes.THREAD], ) - # And the latest event is now the unredacted event. + # And the latest event is the last unredacted event. self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], - unredacted_event_id, + thread_replies[-1], ) + self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])]) + + #################################### + # Redact the last remaining event. # + #################################### + self._redact(thread_replies.pop(0)) + self.assertEquals(thread_replies, []) + + # The event should no longer be considered a thread. 
+        self.assertEquals(self._get_related_events(), [])
+        self.assertEquals(self._get_bundled_aggregations(), {})
+        self.assertEqual(self._get_threads(), [])

     def test_redact_parent_edit(self) -> None:
         """Test that edits of an event are redacted when the original event
-- 
cgit 1.4.1


From 9192d74b0bf2f87b00d3e106a18baa9ce27acda1 Mon Sep 17 00:00:00 2001
From: Quentin Gliech
Date: Tue, 25 Oct 2022 16:25:02 +0200
Subject: Refactor OIDC tests to better mimic an actual OIDC provider. (#13910)

This implements a fake OIDC server, which intercepts calls to the HTTP client.
Improves accuracy of tests by covering more internal methods.

One particular example was the ID token validation, which was previously mocked.

This uncovered an incorrect dependency: Synapse actually requires at least
authlib 0.15.1, not 0.14.0.
---
 changelog.d/13910.misc | 1 +
 pyproject.toml | 2 +-
 synapse/handlers/oidc.py | 15 +-
 tests/federation/test_federation_client.py | 36 +-
 tests/handlers/test_oidc.py | 580 +++++++++++++----------
 tests/rest/client/test_auth.py | 32 +-
 tests/rest/client/test_login.py | 40 +-
 tests/rest/client/utils.py | 136 +++----
 tests/test_utils/__init__.py | 40 +-
 tests/test_utils/oidc.py | 325 ++++++++++++++++
 10 files changed, 747 insertions(+), 460 deletions(-)
 create mode 100644 changelog.d/13910.misc
 create mode 100644 tests/test_utils/oidc.py
(limited to 'tests/rest')

diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc
new file mode 100644
index 0000000000..e906952aab
--- /dev/null
+++ b/changelog.d/13910.misc
@@ -0,0 +1 @@
+Refactor OIDC tests to better mimic an actual OIDC provider.
diff --git a/pyproject.toml b/pyproject.toml
index 6ebac41ed1..7e0feb75aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py
 psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true }
 psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true }
 pysaml2 = { version = ">=4.5.0", optional = true }
-authlib = { version = ">=0.14.0", optional = true }
+authlib = { version = ">=0.15.1", optional = true }
 # systemd-python is necessary for logging to the systemd journal via
 # `systemd.journal.JournalHandler`, as is documented in
 # `contrib/systemd/log_config.yaml`.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index d7a8226900..9759daf043 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -275,6 +275,7 @@ class OidcProvider:
         provider: OidcProviderConfig,
     ):
         self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()

         self._macaroon_generaton = macaroon_generator

@@ -673,6 +674,13 @@ class OidcProvider:

         Returns:
             The decoded claims in the ID token.
         """
+        id_token = token.get("id_token")
+        logger.debug("Attempting to decode JWT id_token %r", id_token)
+
+        # That has theoretically been checked by the caller, so even though
+        # assertions are not enabled in production, it is mainly here to appease mypy
+        assert id_token is not None
+
         metadata = await self.load_metadata()
         claims_params = {
             "nonce": nonce,
@@ -688,9 +696,6 @@ class OidcProvider:

         claim_options = {"iss": {"values": [metadata["issuer"]]}}

-        id_token = token["id_token"]
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
-
         # Try to decode the keys in cache first, then retry by forcing the keys
         # to be reloaded
         jwk_set = await self.load_jwks()
@@ -715,7 +720,9 @@ class OidcProvider:

         logger.debug("Decoded id_token JWT %r; validating", claims)

-        claims.validate(leeway=120)  # allows 2 min of clock skew
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew

         return claims

diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index a538215931..51d3bb8fff 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -12,13 +12,10 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.

-import json
 from unittest import mock

 import twisted.web.client
 from twisted.internet import defer
-from twisted.internet.protocol import Protocol
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactor

 from synapse.api.room_versions import RoomVersions
@@ -26,10 +23,9 @@ from synapse.events import EventBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict
 from synapse.util import Clock

-from tests.test_utils import event_injection
+from tests.test_utils import FakeResponse, event_injection
 from tests.unittest import FederatingHomeserverTestCase


@@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase):

         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "pdus": [
                         create_event_dict,
                         member_event_dict,
@@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase):

         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     "pdus": [
@@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase):

         # We expect an outbound request to /backfill, so stub that out
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     # Mimic the other server returning our new `pulled_event`
@@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
         # other from "yet.another.server"
         self.assertEqual(backfill_num_attempts, 2)
-
-
-def _mock_response(resp: JsonDict):
-    body = json.dumps(resp).encode("utf-8")
-
-    def deliver_body(p: Protocol):
-        p.dataReceived(body)
-        p.connectionLost(Failure(twisted.web.client.ResponseDone()))
-
-    response = mock.Mock(
-        code=200,
-        phrase=b"OK",
-
headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), - length=len(body), - deliverBody=deliver_body, - ) - mock.seal(response) - return response diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b7..5955410524 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os -from typing import Any, Dict +from typing import Any, Dict, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -22,12 +21,15 @@ import pymacaroons from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException +from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +48,6 @@ BASE_URL = "https://synapse/" CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +62,9 @@ DEFAULT_CONFIG = { EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, - "response_types_supported": ["code"], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], - } - elif url == JWKS_URI: - return {"keys": []} - - return {} - - def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.hs_patcher = 
self.fake_server.patch_homeserver(hs=hs) + self.hs_patcher.start() self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase): # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def tearDown(self) -> None: + self.hs_patcher.stop() + return super().tearDown() + + def reset_mocks(self): + """Reset all the Mocks.""" + self.fake_server.reset_mocks() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_server.get_metadata() + metadata.update(values) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + with_sid: bool = False, + ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code, grant = self.fake_server.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=self.provider._client_auth.client_id, + redirect_uri=self.provider._callback_url, + nonce=nonce, + with_sid=with_sid, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session), grant def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_server.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_server.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_server.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() 
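
    # A hypothetical companion test, for illustration only (not part of the
    # original suite; the endpoint value is made up): `metadata_edit` above
    # patches `get_metadata` on the fake server, so a forced reload of the
    # provider metadata should observe the override.
    @override_config({"oidc_config": DEFAULT_CONFIG})
    def test_metadata_override_sketch(self) -> None:
        with self.metadata_edit({"userinfo_endpoint": ISSUER + "userinfo2"}):
            metadata = self.get_success(self.provider.load_metadata(force=True))
        self.assertEqual(metadata.get("userinfo_endpoint"), ISSUER + "userinfo2")
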
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.fake_server.get_jwks_handler.assert_called_once() - # Throw if the JWKS uri is missing - original = self.provider.load_metadata - - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request, _ = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - 
auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, client_redirect_url, None, new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request, _ = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request, _ = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() + request, grant = self.start_authorization(userinfo, with_sid=True) + self.assertIsNotNone(grant.sid) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=grant.sid, ) - 
self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request, _ = self.start_authorization(userinfo) + with self.fake_server.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.fake_server.post_token_handler.return_value = FakeResponse( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - 
return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.fake_server.post_token_handler.return_value = FakeResponse.json( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase): """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" @@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) @@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase): """ Login while using a mapping provider that implements get_extra_attributes. 
""" - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request, _ = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. 
userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. 
user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, @@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. 
self.get_success( @@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. 
userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase): Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = 
self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", + idp_id=self.provider.idp_id, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict - - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. - - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), - ) - request = _build_callback_request("code", state, session) - - await handler.handle_oidc_callback(request) - - def _build_callback_request( code: str, state: str, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 090cef5216..ebf653d018 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase): * checking that the original operation succeeds """ + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # run the UIA-via-SSO flow session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - login_resp = self.helper.login_via_oidc("username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": 
TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) channel = self.delete_device( @@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase): @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" + + fake_oidc_server = self.helper.fake_oidc_server() + # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device @@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase): """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - login_resp = self.helper.login_via_oidc("username", expected_status=403) + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_server, "username", expected_status=403 + ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) self.assertEqual( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e801ba8c8b..ff5baa9f0a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -36,7 +36,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ 
-626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): TEST_CLIENT_REDIRECT_URL, ) - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + channel, _ = self.helper.complete_oidc_auth( + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} + ) # that should serve a confirmation page self.assertEqual(channel.code, 200, channel.result) @@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request("oidc") + fake_oidc_server = self.helper.fake_oidc_server() + + with fake_oidc_server.patch_homeserver(hs=self.hs): + channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" + fake_oidc_server = self.helper.fake_oidc_server() + # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + channel, _ = self.helper.auth_via_oidc( + fake_oidc_server, + {"sub": "tester", "displayname": "Jonny"}, + TEST_CLIENT_REDIRECT_URL, ) # that should redirect to the username picker diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c249a42bb6..967d229223 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,7 +31,6 @@ from typing import ( Tuple, overload, ) -from unittest.mock import patch from urllib.parse import urlencode import attr @@ -46,8 +45,19 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_ISSUER = "https://issuer.test/" +TEST_OIDC_CONFIG = { + "enabled": True, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} @attr.s(auto_attribs=True) @@ -543,12 +553,28 @@ class RestHelper: return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: + """Create a ``FakeOidcServer``. 
+
+        This can be used in conjunction with ``login_via_oidc``::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
+        """
+
+        return FakeOidcServer(
+            clock=self.hs.get_clock(),
+            issuer=issuer,
+        )
+
     def login_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         remote_user_id: str,
+        with_sid: bool = False,
         expected_status: int = 200,
-    ) -> JsonDict:
-        """Log in via OIDC
+    ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
+        """Log in (as a new user) via OIDC

         Returns the result of the final token login.

@@ -560,7 +586,10 @@ class RestHelper:
             the normal places.
         """
         client_redirect_url = "https://x"
-        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+        userinfo = {"sub": remote_user_id}
+        channel, grant = self.auth_via_oidc(
+            fake_server, userinfo, client_redirect_url, with_sid=with_sid
+        )

         # expect a confirmation page
         assert channel.code == HTTPStatus.OK, channel.result
@@ -585,14 +614,16 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body
+        return channel.json_body, grant

     def auth_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         user_info_dict: JsonDict,
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Perform an OIDC authentication flow via a mock OIDC provider.

         This can be used for either login or user-interactive auth.
@@ -616,6 +647,7 @@ class RestHelper:
             the login redirect endpoint
         ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
             of the UI auth.
+        with_sid: if True, generates a random `sid` (OIDC session ID)

         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -625,14 +657,15 @@ class RestHelper:

         cookies: Dict[str, str] = {}

-        # if we're doing a ui auth, hit the ui auth redirect endpoint
-        if ui_auth_session_id:
-            # can't set the client redirect url for UI Auth
-            assert client_redirect_url is None
-            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
-        else:
-            # otherwise, hit the login redirect endpoint
-            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+        with fake_server.patch_homeserver(hs=self.hs):
+            # if we're doing a ui auth, hit the ui auth redirect endpoint
+            if ui_auth_session_id:
+                # can't set the client redirect url for UI Auth
+                assert client_redirect_url is None
+                oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+            else:
+                # otherwise, hit the login redirect endpoint
+                oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)

         # we now have a URI for the OIDC IdP, but we skip that and go straight
         # back to synapse's OIDC callback resource. However, we do need the "state"
         # param that synapse passes to the client.
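        # (The fake IdP is never actually contacted in this flow:
        # complete_oidc_auth below mints the authorization code directly via
        # fake_server.start_authorization.)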
         oauth_uri_path, _ = oauth_uri.split("?", 1)
-        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+        assert oauth_uri_path == fake_server.authorization_endpoint, (
             "unexpected SSO URI " + oauth_uri_path
         )
-        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+        return self.complete_oidc_auth(
+            fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
+        )

     def complete_oidc_auth(
         self,
+        fake_server: FakeOidcServer,
         oauth_uri: str,
         cookies: Mapping[str, str],
         user_info_dict: JsonDict,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Mock out an OIDC authentication flow

         Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@@ -661,50 +698,37 @@ class RestHelper:
         Requires the OIDC callback resource to be mounted at the normal place.

         Args:
+            fake_server: the fake OIDC server with which the auth should be done
             oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
                 from initiate_sso_login or initiate_sso_ui_auth).
             cookies: the cookies set by synapse's redirect endpoint, which will be
                 sent back to the callback endpoint.
             user_info_dict: the remote userinfo that the OIDC provider should present.
                 Typically this should be '{"sub": ""}'.
+            with_sid: if True, generates a random `sid` (OIDC session ID)

         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
         """
         _, oauth_uri_qs = oauth_uri.split("?", 1)
         params = urllib.parse.parse_qs(oauth_uri_qs)
+
+        code, grant = fake_server.start_authorization(
+            scope=params["scope"][0],
+            userinfo=user_info_dict,
+            client_id=params["client_id"][0],
+            redirect_uri=params["redirect_uri"][0],
+            nonce=params["nonce"][0],
+            with_sid=with_sid,
+        )
+        state = params["state"][0]
+
         callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
-            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+            urllib.parse.urlencode({"state": state, "code": code}),
         )

-        # before we hit the callback uri, stub out some methods in the http client so
-        # that we don't have to handle full HTTPS requests.
-        # (expected url, json response) pairs, in the order we expect them.
-        expected_requests = [
-            # first we get a hit to the token endpoint, which we tell to return
-            # a dummy OIDC access token
-            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
-            # and then one to the user_info endpoint, which returns our remote user id.
-            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
-        ]
-
-        async def mock_req(
-            method: str,
-            uri: str,
-            data: Optional[dict] = None,
-            headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-        ):
-            (expected_uri, resp_obj) = expected_requests.pop(0)
-            assert uri == expected_uri
-            resp = FakeResponse(
-                code=HTTPStatus.OK,
-                phrase=b"OK",
-                body=json.dumps(resp_obj).encode("utf-8"),
-            )
-            return resp
-
-        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+        with fake_server.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
                 self.hs.get_reactor(),
@@ -715,7 +739,7 @@ class RestHelper:
                     ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
                 ],
             )
-        return channel
+        return channel, grant

     def initiate_sso_login(
         self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
@@ -806,21 +830,3 @@ class RestHelper:
         assert len(p.links) == 1, "not exactly one link in confirmation page"
         oauth_uri = p.links[0]
         return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0d0d6faf0d..e62ebcc6a5 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -15,17 +15,24 @@ """ Utilities for running the unit tests """ +import json import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Tuple, TypeVar from unittest.mock import Mock import attr +import zope.interface from twisted.python.failure import Failure from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.types import JsonDict TV = TypeVar("TV") @@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock: return Mock(side_effect=cb) -@attr.s -class FakeResponse: +# Type ignore: it does not fully implement IResponse, but is good enough for tests +@zope.interface.implementer(IResponse) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeResponse: # type: ignore[misc] """A fake twisted.web.IResponse object there is a similar class at treq.test.test_response, but it lacks a `phrase` attribute, and didn't support deliverBody until recently. """ - # HTTP response code - code = attr.ib(type=int) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) - # HTTP response phrase (eg b'OK' for a 200) - phrase = attr.ib(type=bytes) + # HTTP response code + code: int = 200 # body of the response - body = attr.ib(type=bytes) + body: bytes = b"" + + headers: Headers = attr.Factory(Headers) + + @property + def phrase(self): + return RESPONSES.get(self.code, b"Unknown Status") + + @property + def length(self): + return len(self.body) def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + @classmethod + def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + headers = Headers({"Content-Type": ["application/json"]}) + body = json.dumps(payload).encode("utf-8") + return cls(code=code, body=body, headers=headers) + # A small image used in some tests. # diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py new file mode 100644 index 0000000000..de134bbc89 --- /dev/null +++ b/tests/test_utils/oidc.py @@ -0,0 +1,325 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import attr + +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.server import HomeServer +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests.test_utils import FakeResponse + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeAuthorizationGrant: + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + sid: Optional[str] + + +class FakeOidcServer: + """A fake OpenID Connect Provider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks_handler: Mock + get_metadata_handler: Mock + get_userinfo_handler: Mock + post_token_handler: Mock + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self._clock = clock + self.issuer = issuer + + self.request = Mock(side_effect=self._request) + self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler) + self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler) + self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler) + self.post_token_handler = Mock(side_effect=self._post_token_handler) + + # A code -> grant mapping + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + # An access token -> grant mapping + self._sessions: Dict[str, FakeAuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) + + self._id_token_overrides: Dict[str, Any] = {} + + def reset_mocks(self): + self.request.reset_mock() + self.get_jwks_handler.reset_mock() + self.get_metadata_handler.reset_mock() + self.get_userinfo_handler.reset_mock() + self.post_token_handler.reset_mock() + + def patch_homeserver(self, hs: HomeServer): + """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. 
+
+        This patch should be used whenever the HS is expected to perform requests
+        to the OIDC provider, e.g.::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            with fake_oidc_server.patch_homeserver(hs):
+                self.make_request("GET", "/_matrix/client/r0/login/sso/redirect")
+        """
+        return patch.object(hs.get_proxied_http_client(), "request", self.request)
+
+    @property
+    def authorization_endpoint(self) -> str:
+        return self.issuer + "authorize"
+
+    @property
+    def token_endpoint(self) -> str:
+        return self.issuer + "token"
+
+    @property
+    def userinfo_endpoint(self) -> str:
+        return self.issuer + "userinfo"
+
+    @property
+    def metadata_endpoint(self) -> str:
+        return self.issuer + ".well-known/openid-configuration"
+
+    @property
+    def jwks_uri(self) -> str:
+        return self.issuer + "jwks"
+
+    def get_metadata(self) -> dict:
+        return {
+            "issuer": self.issuer,
+            "authorization_endpoint": self.authorization_endpoint,
+            "token_endpoint": self.token_endpoint,
+            "jwks_uri": self.jwks_uri,
+            "userinfo_endpoint": self.userinfo_endpoint,
+            "response_types_supported": ["code"],
+            "subject_types_supported": ["public"],
+            "id_token_signing_alg_values_supported": ["ES256"],
+        }
+
+    def get_jwks(self) -> dict:
+        return self._jwks.as_dict()
+
+    def get_userinfo(self, access_token: str) -> Optional[dict]:
+        """Given an access token, get the userinfo of the associated session."""
+        session = self._sessions.get(access_token, None)
+        if session is None:
+            return None
+        return session.userinfo
+
+    def _sign(self, payload: dict) -> str:
+        from authlib.jose import JsonWebSignature
+
+        jws = JsonWebSignature()
+        kid = self.get_jwks()["keys"][0]["kid"]
+        protected = {"alg": "ES256", "kid": kid}
+        json_payload = json.dumps(payload)
+        return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
+
+    def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
+        now = self._clock.time()
+        id_token = {
+            **grant.userinfo,
+            "iss": self.issuer,
+            "aud": grant.client_id,
+            "iat": now,
+            "nbf": now,
+            "exp": now + 600,
+        }
+
+        if grant.nonce is not None:
+            id_token["nonce"] = grant.nonce
+
+        if grant.sid is not None:
+            id_token["sid"] = grant.sid
+
+        id_token.update(self._id_token_overrides)
+
+        return self._sign(id_token)
+
+    def id_token_override(self, overrides: dict):
+        """Temporarily patch the ID token generated by the token endpoint."""
+        return patch.object(self, "_id_token_overrides", overrides)
+
+    def start_authorization(
+        self,
+        client_id: str,
+        scope: str,
+        redirect_uri: str,
+        userinfo: dict,
+        nonce: Optional[str] = None,
+        with_sid: bool = False,
+    ) -> Tuple[str, FakeAuthorizationGrant]:
+        """Start an authorization request, and get back the code that the
+        authorization endpoint would hand back to the client, ready to be
+        exchanged at the token endpoint."""
+        code = random_string(10)
+        sid = None
+        if with_sid:
+            sid = random_string(10)
+
+        grant = FakeAuthorizationGrant(
+            userinfo=userinfo,
+            scope=scope,
+            redirect_uri=redirect_uri,
+            nonce=nonce,
+            client_id=client_id,
+            sid=sid,
+        )
+        self._authorization_grants[code] = grant
+
+        return code, grant
+
+    def exchange_code(self, code: str) -> Optional[Dict[str, Any]]:
+        grant = self._authorization_grants.pop(code, None)
+        if grant is None:
+            return None
+
+        access_token = random_string(10)
+        self._sessions[access_token] = grant
+
+        token = {
+            "token_type": "Bearer",
+            "access_token": access_token,
+            "expires_in": 3600,
+            "scope": grant.scope,
+        }
+
+        if "openid" in grant.scope:
+            token["id_token"] = self.generate_id_token(grant)
+
+        return dict(token)
+
+    def buggy_endpoint(
+        self,
+        *,
+        jwks: bool = False,
+        metadata: bool = False,
+        token: bool = False,
+        userinfo: bool = False,
+    ):
+        """A context manager which makes a set of endpoints return a 500 error.
+
+        Args:
+            jwks: If True, makes the JWKS endpoint return a 500 error.
+            metadata: If True, makes the OIDC Discovery endpoint return a 500 error.
+            token: If True, makes the token endpoint return a 500 error.
+            userinfo: If True, makes the userinfo endpoint return a 500 error.
+        """
+        buggy = FakeResponse(code=500, body=b"Internal server error")
+
+        patches = {}
+        if jwks:
+            patches["get_jwks_handler"] = Mock(return_value=buggy)
+        if metadata:
+            patches["get_metadata_handler"] = Mock(return_value=buggy)
+        if token:
+            patches["post_token_handler"] = Mock(return_value=buggy)
+        if userinfo:
+            patches["get_userinfo_handler"] = Mock(return_value=buggy)
+
+        return patch.multiple(self, **patches)
+
+    async def _request(
+        self,
+        method: str,
+        uri: str,
+        data: Optional[bytes] = None,
+        headers: Optional[Headers] = None,
+    ) -> IResponse:
+        """The override of the SimpleHttpClient#request() method"""
+        access_token: Optional[str] = None
+
+        if headers is None:
+            headers = Headers()
+
+        # Try to find the access token in the headers if any
+        auth_headers = headers.getRawHeaders(b"Authorization")
+        if auth_headers:
+            parts = auth_headers[0].split(b" ")
+            if parts[0] == b"Bearer" and len(parts) == 2:
+                access_token = parts[1].decode("ascii")
+
+        if method == "POST":
+            # If the method is POST, assume it has a url-encoded body
+            if data is None or headers.getRawHeaders(b"Content-Type") != [
+                b"application/x-www-form-urlencoded"
+            ]:
+                return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+            params = parse_qs(data.decode("utf-8"))
+
+            if uri == self.token_endpoint:
+                # Even though this endpoint should be protected, this does not
+                # check for client authentication: we skip it for simplicity, and
+                # because client authentication is covered by separate,
+                # standalone tests.
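+                # (A real OP would typically also authenticate the client here,
+                # e.g. by checking an HTTP Basic `Authorization` header or a
+                # `client_secret` field included in the form body.)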
+ return self.post_token_handler(params) + + elif method == "GET": + if uri == self.jwks_uri: + return self.get_jwks_handler() + elif uri == self.metadata_endpoint: + return self.get_metadata_handler() + elif uri == self.userinfo_endpoint: + return self.get_userinfo_handler(access_token=access_token) + + return FakeResponse(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks_handler(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return FakeResponse.json(payload=self.get_jwks()) + + def _get_metadata_handler(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return FakeResponse.json(payload=self.get_metadata()) + + def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return FakeResponse(code=401) + user_info = self.get_userinfo(access_token) + if user_info is None: + return FakeResponse(code=401) + + return FakeResponse.json(payload=user_info) + + def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + grant = self.exchange_code(code=code[0]) + if grant is None: + return FakeResponse.json(code=400, payload={"error": "invalid_grant"}) + + return FakeResponse.json(payload=grant) -- cgit 1.4.1 From 6a6e1e8c0711939338f25d8d41d1e4d33d984949 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 28 Oct 2022 10:53:34 +0000 Subject: Fix room creation being rate limited too aggressively since Synapse v1.69.0. (#14314) * Introduce a test for the old behaviour which we want to restore * Reintroduce the old behaviour in a simpler way * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) * Use 1 credit instead of 2 for creating a room: be more lenient than before Notably, the UI in Element Web was still broken after restoring to prior behaviour. After discussion, we agreed that it would be sensible to increase the limit. Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/14314.bugfix | 1 + synapse/api/ratelimiting.py | 8 +++++- synapse/handlers/room.py | 16 ++++++++---- tests/rest/client/test_rooms.py | 54 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14314.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14314.bugfix b/changelog.d/14314.bugfix new file mode 100644 index 0000000000..8be47ee083 --- /dev/null +++ b/changelog.d/14314.bugfix @@ -0,0 +1 @@ +Fix room creation being rate limited too aggressively since Synapse v1.69.0. \ No newline at end of file diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 044c7d4926..511790c7c5 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -343,6 +343,7 @@ class RequestRatelimiter: requester: Requester, update: bool = True, is_admin_redaction: bool = False, + n_actions: int = 1, ) -> None: """Ratelimits requests. @@ -355,6 +356,8 @@ class RequestRatelimiter: is_admin_redaction: Whether this is a room admin/moderator redacting an event. If so then we may apply different ratelimits depending on config. + n_actions: Multiplier for the number of actions to apply to the + rate limiter at once. 
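+                For example, a single call with ``n_actions=3`` counts against
+                the limit in the same way as three separate one-action calls.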
Raises: LimitExceededError if the request should be ratelimited @@ -383,7 +386,9 @@ class RequestRatelimiter: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + await self.admin_redaction_ratelimiter.ratelimit( + requester, update=update, n_actions=n_actions + ) else: # Override rate and burst count per-user await self.request_ratelimiter.ratelimit( @@ -391,4 +396,5 @@ class RequestRatelimiter: rate_hz=messages_per_second, burst_count=burst_count, update=update, + n_actions=n_actions, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..d74b675adc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -559,7 +559,6 @@ class RoomCreationHandler: invite_list=[], initial_state=initial_state, creation_content=creation_content, - ratelimit=False, ) # Transfer membership events @@ -753,6 +752,10 @@ class RoomCreationHandler: ) if ratelimit: + # Rate limit once in advance, but don't rate limit the individual + # events in the room — room creation isn't atomic and it's very + # janky if half the events in the initial state don't make it because + # of rate limiting. await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( @@ -913,7 +916,6 @@ class RoomCreationHandler: room_alias=room_alias, power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, - ratelimit=ratelimit, ) if "name" in config: @@ -1037,7 +1039,6 @@ class RoomCreationHandler: room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, - ratelimit: bool = True, ) -> Tuple[int, str, int]: """Sends the initial events into a new room. Sends the room creation, membership, and power level events into the room sequentially, then creates and batches up the @@ -1046,6 +1047,8 @@ class RoomCreationHandler: `power_level_content_override` doesn't apply when initial state has power level state event content. + Rate limiting should already have been applied by this point. + Returns: A tuple containing the stream ID, event ID and depth of the last event sent to the room. 
@@ -1144,7 +1147,7 @@ class RoomCreationHandler:
             creator.user,
             room_id,
             "join",
-            ratelimit=ratelimit,
+            ratelimit=False,
             content=creator_join_profile,
             new_room=True,
             prev_event_ids=[last_sent_event_id],
@@ -1269,7 +1272,10 @@ class RoomCreationHandler:
             events_to_send.append((encryption_event, encryption_context))
 
         last_event = await self.event_creation_handler.handle_new_client_event(
-            creator, events_to_send, ignore_shadow_ban=True
+            creator,
+            events_to_send,
+            ignore_shadow_ban=True,
+            ratelimit=False,
         )
         assert last_event.internal_metadata.stream_ordering is not None
         return last_event.internal_metadata.stream_ordering, last_event.event_id, depth
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 716366eb90..1084d4ad9d 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -54,6 +54,7 @@ from tests.http.server._base import make_request_with_cancellation_test
 from tests.storage.test_stream import PaginationTestCase
 from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import create_event
+from tests.unittest import override_config
 
 PATH_PREFIX = b"/_matrix/client/api/v1"
 
@@ -871,6 +872,41 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
         self.assertEqual(join_mock.call_count, 0)
 
+    def _create_basic_room(self) -> Tuple[int, object]:
+        """
+        Tries to create a basic room and returns the response code and body.
+        """
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            {},
+        )
+        return channel.code, channel.json_body
+
+    @override_config(
+        {
+            "rc_message": {"per_second": 0.2, "burst_count": 10},
+        }
+    )
+    def test_room_creation_ratelimiting(self) -> None:
+        """
+        Regression test for #14312, where ratelimiting was made too strict.
+        Clients should be able to create 10 rooms in a row
+        without hitting rate limits, using default rate limit config.
+        (We override rate limiting config back to its default value.)
+
+        To ensure we don't make ratelimiting too generous accidentally,
+        also check that we can't create an 11th room.
+        """
+
+        for _ in range(10):
+            code, json_body = self._create_basic_room()
+            self.assertEqual(code, HTTPStatus.OK, json_body)
+
+        # The 11th room hits the rate limit.
+        code, json_body = self._create_basic_room()
+        self.assertEqual(code, HTTPStatus.TOO_MANY_REQUESTS, json_body)
+
 
 class RoomTopicTestCase(RoomBase):
     """Tests /rooms/$room_id/topic REST events."""
@@ -1390,10 +1426,22 @@ class RoomJoinRatelimitTestCase(RoomBase):
     )
     def test_join_local_ratelimit(self) -> None:
         """Tests that local joins are actually rate-limited."""
-        for _ in range(3):
-            self.helper.create_room_as(self.user_id)
+        # Create 4 rooms
+        room_ids = [
+            self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4)
+        ]
+
+        joiner_user_id = self.register_user("joiner", "secret")
+        # Now make a new user try to join some of them.
-        self.helper.create_room_as(self.user_id, expect_code=429)
+        # The user can join 3 rooms
+        for room_id in room_ids[0:3]:
+            self.helper.join(room_id, joiner_user_id)
+
+        # But the user cannot join a 4th room
+        self.helper.join(
+            room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS
+        )
 
     @unittest.override_config(
         {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
-- 
cgit 1.4.1


From cc3a52b33df72bb4230367536b924a6d1f510d36 Mon Sep 17 00:00:00 2001
From: Quentin Gliech
Date: Mon, 31 Oct 2022 18:07:30 +0100
Subject: Support OIDC backchannel logouts (#11414)

If configured, an OIDC IdP can log a user's session out of
Synapse when they log out of the identity provider.

The IdP sends a request directly to Synapse (and must be
configured with an endpoint) when a user logs out.
---
 changelog.d/11414.feature                          |   1 +
 docs/openid.md                                     |  14 +
 docs/usage/configuration/config_documentation.md   |   9 +
 synapse/config/oidc.py                             |  12 +
 synapse/handlers/oidc.py                           | 381 ++++++++++++++++++--
 synapse/handlers/sso.py                            |  71 ++++
 synapse/rest/synapse/client/oidc/__init__.py       |   4 +
 .../client/oidc/backchannel_logout_resource.py     |  35 ++
 synapse/storage/databases/main/registration.py     |  21 ++
 tests/rest/client/test_auth.py                     | 390 +++++++++++++++++++--
 tests/rest/client/utils.py                         |  55 ++-
 tests/server.py                                    |   6 +
 tests/test_utils/oidc.py                           |  27 +-
 13 files changed, 960 insertions(+), 66 deletions(-)
 create mode 100644 changelog.d/11414.feature
 create mode 100644 synapse/rest/synapse/client/oidc/backchannel_logout_resource.py

(limited to 'tests/rest')

diff --git a/changelog.d/11414.feature b/changelog.d/11414.feature
new file mode 100644
index 0000000000..fc035e50a7
--- /dev/null
+++ b/changelog.d/11414.feature
@@ -0,0 +1 @@
+Support back-channel logouts from OpenID Connect providers.
diff --git a/docs/openid.md b/docs/openid.md
index 87ebea4c29..37c5eb244d 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -49,6 +49,13 @@ setting in your configuration file.
 See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as the text below for example
 configurations for specific providers.
 
+## OIDC Back-Channel Logout
+
+Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications.
+
+This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session.
+This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as the destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout`
+
 ## Sample configs
 
 Here are a few configs for providers that should work with Synapse.
@@ -123,6 +130,9 @@ oidc_providers:
 
 [Keycloak][keycloak-idp] is an open-source IdP maintained by Red Hat.
 
+Keycloak supports OIDC Back-Channel Logout, which sends a logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak.
+This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak.
+
 Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
 
 1. Click `Clients` in the sidebar and click `Create`
@@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
 | Client Protocol | `openid-connect` |
 | Access Type | `confidential` |
 | Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
+| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` |
+| Backchannel Logout Session Required (optional) | `On` |
 
 5. Click `Save`
 6. On the Credentials tab, update the fields:
@@ -167,7 +179,9 @@ oidc_providers:
     config:
       localpart_template: "{{ user.preferred_username }}"
       display_name_template: "{{ user.name }}"
+    backchannel_logout_enabled: true # Optional
 ```
+
 ### Auth0
 
 [Auth0][auth0] is a hosted SaaS IdP solution.
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 97fb505a5f..44358faf59 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -3021,6 +3021,15 @@ Options for each entry include:
        which is set to the claims returned by the UserInfo Endpoint and/or
        in the ID Token.
 
+* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications.
+  Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`.
+  Defaults to `false`.
+
+* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the
+  `sub` claim matches the subject claim received during login. This check can be disabled by setting
+  this to `true`. Defaults to `false`.
+
+  You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`.
 
 It is possible to configure Synapse to only allow logins if certain attributes
 match particular values in the OIDC userinfo. The requirements can be listed under
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 5418a332da..0bd83f4010 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
         "userinfo_endpoint": {"type": "string"},
         "jwks_uri": {"type": "string"},
         "skip_verification": {"type": "boolean"},
+        "backchannel_logout_enabled": {"type": "boolean"},
+        "backchannel_logout_ignore_sub": {"type": "boolean"},
         "user_profile_method": {
             "type": "string",
             "enum": ["auto", "userinfo_endpoint"],
@@ -292,6 +294,10 @@ def _parse_oidc_config_dict(
         token_endpoint=oidc_config.get("token_endpoint"),
         userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
         jwks_uri=oidc_config.get("jwks_uri"),
+        backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False),
+        backchannel_logout_ignore_sub=oidc_config.get(
+            "backchannel_logout_ignore_sub", False
+        ),
         skip_verification=oidc_config.get("skip_verification", False),
         user_profile_method=oidc_config.get("user_profile_method", "auto"),
         allow_existing_users=oidc_config.get("allow_existing_users", False),
@@ -368,6 +374,12 @@ class OidcProviderConfig:
     # "openid" scope is used.
     jwks_uri: Optional[str]
 
+    # Whether Synapse should react to backchannel logouts
+    backchannel_logout_enabled: bool
+
+    # Whether Synapse should ignore the `sub` claim in backchannel logouts or not.
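+    # (As the documentation above explains, this is mainly useful when the
+    # configured `user_mapping_provider` maps users by a claim other than
+    # `sub`, in which case the `sub` in the logout token would not match the
+    # external ID that Synapse stored at login time.)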
+    backchannel_logout_ignore_sub: bool
+
     # Whether to skip metadata verification
     skip_verification: bool
 
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 9759daf043..867973dcca 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -12,14 +12,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import binascii
 import inspect
+import json
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 from urllib.parse import urlencode, urlparse
 
 import attr
+import unpaddedbase64
 from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken, jwt
+from authlib.jose import JsonWebToken, JWTClaims
+from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
 from authlib.oidc.core import CodeIDToken, UserInfo
@@ -35,9 +49,12 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 from twisted.web.http_headers import Headers
 
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -88,6 +105,8 @@ class Token(TypedDict):
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+C = TypeVar("C")
+
 
 #: A JWK Set, as per RFC7517 sec 5.
 class JWKS(TypedDict):
@@ -247,6 +266,80 @@ class OidcHandler:
 
         await oidc_provider.handle_oidc_callback(request, session_data, code)
 
+    async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        This extracts the logout_token from the request and tries to figure out
+        which OpenID Provider it is coming from. This works by matching the iss claim
+        with the issuer and the aud claim with the client_id.
+
+        Since at this point we don't know who signed the JWT, we can't just
+        decode it using authlib, which always verifies the signature. We
+        have to decode it manually without validating the signature. The actual JWT
+        verification is done in the `OidcProvider.handle_backchannel_logout` method,
+        once we have figured out which provider sent the request.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+        logout_token = parse_string(request, "logout_token")
+        if logout_token is None:
+            raise SynapseError(400, "Missing logout_token in request")
+
+        # A JWT looks like this:
+        #    header.payload.signature
+        # where all parts are encoded with urlsafe base64.
+        # The aud and iss claims we care about are in the payload part, which
+        # is a JSON object.
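+        # For instance, a decoded logout token payload looks roughly like this
+        # (the issuer and audience here are illustrative values):
+        #
+        #     {
+        #         "iss": "https://issuer.example.com/",
+        #         "aud": "a-client-id",
+        #         "sid": "<session id>",
+        #         "events": {
+        #             "http://schemas.openid.net/event/backchannel-logout": {}
+        #         }
+        #     }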
+        try:
+            # By destructuring the list after splitting, we ensure that we have
+            # exactly 3 segments
+            _, payload, _ = logout_token.split(".")
+        except ValueError:
+            raise SynapseError(400, "Invalid logout_token in request")
+
+        try:
+            payload_bytes = unpaddedbase64.decode_base64(payload)
+            claims = json_decoder.decode(payload_bytes.decode("utf-8"))
+        except (json.JSONDecodeError, binascii.Error, UnicodeError):
+            raise SynapseError(400, "Invalid logout_token payload in request")
+
+        try:
+            # Let's extract the iss and aud claims
+            iss = claims["iss"]
+            aud = claims["aud"]
+            # The aud claim can be either a string or a list of strings. Here we
+            # normalize it as a list of strings.
+            if isinstance(aud, str):
+                aud = [aud]
+
+            # Check that we have the right types for the aud and the iss claims
+            if not isinstance(iss, str) or not isinstance(aud, list):
+                raise TypeError()
+            for a in aud:
+                if not isinstance(a, str):
+                    raise TypeError()
+
+            # At this point we have properly checked both claim types
+            issuer: str = iss
+            audience: List[str] = aud
+        except (TypeError, KeyError):
+            raise SynapseError(400, "Invalid issuer/audience in logout_token")
+
+        # Now that we know the audience and the issuer, we can figure out
+        # which provider it is coming from
+        oidc_provider: Optional[OidcProvider] = None
+        for provider in self._providers.values():
+            if provider.issuer == issuer and provider.client_id in audience:
+                oidc_provider = provider
+                break
+
+        if oidc_provider is None:
+            raise SynapseError(400, "Could not find the OP that issued this event")
+
+        # Ask the provider to handle the logout request.
+        await oidc_provider.handle_backchannel_logout(request, logout_token)
+
 
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint"""
@@ -342,6 +435,7 @@ class OidcProvider:
         self.idp_brand = provider.idp_brand
 
         self._sso_handler = hs.get_sso_handler()
+        self._device_handler = hs.get_device_handler()
 
         self._sso_handler.register_identity_provider(self)
 
@@ -400,6 +494,41 @@ class OidcProvider:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
             m.validate_jwks_uri()
 
+        if self._config.backchannel_logout_enabled:
+            if not m.get("backchannel_logout_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled for issuer %r "
+                    "but it does not advertise support for it",
+                    self.issuer,
+                )
+
+            elif not m.get("backchannel_logout_session_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled and supported "
+                    "by issuer %r but it might not send a session ID with "
+                    "logout tokens, which is required for the logouts to work",
+                    self.issuer,
+                )
+
+            if not self._config.backchannel_logout_ignore_sub:
+                # If OIDC backchannel logouts are enabled, the configured
+                # `user_mapping_provider` should use the `sub` claim. We verify
+                # that by mapping a dummy user and checking that we get back the
+                # sub claim
+                user = UserInfo({"sub": "thisisasubject"})
+                try:
+                    subject = self._user_mapping_provider.get_remote_user_id(user)
+                    if subject != user["sub"]:
+                        raise ValueError("Unexpected subject")
+                except Exception:
+                    logger.warning(
+                        f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
+                        "but it looks like the configured `user_mapping_provider` "
+                        "does not use the `sub` claim as subject. If it is the case, "
+                        "and you want Synapse to ignore the `sub` claim in OIDC "
+                        "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
+                        "to `true` in the issuer config."
+                    )
+
     @property
     def _uses_userinfo(self) -> bool:
         """Returns True if the ``userinfo_endpoint`` should be used.
@@ -415,6 +544,16 @@ class OidcProvider:
             or self._user_profile_method == "userinfo_endpoint"
         )
 
+    @property
+    def issuer(self) -> str:
+        """The issuer identifying this provider."""
+        return self._config.issuer
+
+    @property
+    def client_id(self) -> str:
+        """The client_id used when interacting with this provider."""
+        return self._config.client_id
+
     async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
         """Return the provider metadata.
 
@@ -662,6 +801,59 @@ class OidcProvider:
 
         return UserInfo(resp)
 
+    async def _verify_jwt(
+        self,
+        alg_values: List[str],
+        token: str,
+        claims_cls: Type[C],
+        claims_options: Optional[dict] = None,
+        claims_params: Optional[dict] = None,
+    ) -> C:
+        """Decode and validate a JWT, re-fetching the JWKS as needed.
+
+        Args:
+            alg_values: list of `alg` values allowed when verifying the JWT.
+            token: the JWT.
+            claims_cls: the JWTClaims class to use to validate the claims.
+            claims_options: dict of options passed to the `claims_cls` constructor.
+            claims_params: dict of params passed to the `claims_cls` constructor.
+
+        Returns:
+            The decoded claims in the JWT.
+        """
+        jwt = JsonWebToken(alg_values)
+
+        logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
+
+        # Try to decode using the cached keys first, then retry by forcing the
+        # keys to be reloaded
+        jwk_set = await self.load_jwks()
+        try:
+            claims = jwt.decode(
+                token,
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
+                claims_params=claims_params,
+            )
+        except ValueError:
+            logger.info("Reloading JWKS after decode error")
+            jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
+            claims = jwt.decode(
+                token,
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
+                claims_params=claims_params,
+            )
+
+        logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
+
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew
+        return claims
+
     async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.
 
@@ -675,13 +867,13 @@ class OidcProvider:
             The decoded claims in the ID token.
         """
         id_token = token.get("id_token")
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
 
         # That has theoretically been checked by the caller, so even though
         # assertions are not enabled in production, it is mainly here to appease mypy
         assert id_token is not None
 
         metadata = await self.load_metadata()
+
         claims_params = {
             "nonce": nonce,
             "client_id": self._client_auth.client_id,
@@ -691,38 +883,17 @@ class OidcProvider:
         }
         if token.get("access_token"):
             # If we got an `access_token`, there should be an `at_hash` claim
            # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
 
-        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-        jwt = JsonWebToken(alg_values)
-
-        claim_options = {"iss": {"values": [metadata["issuer"]]}}
+        claims_options = {"iss": {"values": [metadata["issuer"]]}}
 
-        # Try to decode the keys in cache first, then retry by forcing the keys
-        # to be reloaded
-        jwk_set = await self.load_jwks()
-        try:
-            claims = jwt.decode(
-                id_token,
-                key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
-                claims_params=claims_params,
-            )
-        except ValueError:
-            logger.info("Reloading JWKS after decode error")
-            jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
-            claims = jwt.decode(
-                id_token,
-                key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
-                claims_params=claims_params,
-            )
-
-        logger.debug("Decoded id_token JWT %r; validating", claims)
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
 
-        claims.validate(
-            now=self._clock.time(), leeway=120
-        )  # allows 2 min of clock skew
+        claims = await self._verify_jwt(
+            alg_values=alg_values,
+            token=id_token,
+            claims_cls=CodeIDToken,
+            claims_options=claims_options,
+            claims_params=claims_params,
+        )
 
         return claims
 
@@ -1043,6 +1214,146 @@ class OidcProvider:
         # to be strings.
         return str(remote_user_id)
 
+    async def handle_backchannel_logout(
+        self, request: SynapseRequest, logout_token: str
+    ) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        The OIDC Provider posts a logout token to this endpoint when a user
+        session ends. That token is a JWT signed with the same keys as
+        ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
+        validate the JWT and figure out what session to end.
+
+        Args:
+            request: The request to respond to
+            logout_token: The logout token (a JWT) extracted from the request body
+        """
+        # Back-Channel Logout can be disabled in the config, hence this check.
+        # This is not that important for now since Synapse is registered
+        # manually to the OP, so not specifying the backchannel-logout URI is
+        # as effective as disabling it here. It might make more sense if we
+        # support dynamic registration in Synapse at some point.
+        if not self._config.backchannel_logout_enabled:
+            logger.warning(
+                f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
+            )
+
+            # TODO: this responds with a 400 status code, which is what the OIDC
+            # Back-Channel Logout spec expects, but the spec also suggests answering
+            # with a JSON object, with the `error` and `error_description` fields
+            # set, which we are not doing here.
+            # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
+            raise SynapseError(
+                400, "OpenID Connect Back-Channel Logout is disabled for this provider"
+            )
+
+        metadata = await self.load_metadata()
+
+        # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
+        #   A Logout Token MUST be signed and MAY also be encrypted. The same
+        #   keys are used to sign and encrypt Logout Tokens as are used for ID
+        #   Tokens. If the Logout Token is encrypted, it SHOULD replicate the
+        #   iss (issuer) claim in the JWT Header Parameters, as specified in
+        #   Section 5.3 of [JWT].
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+        # As per sec. 2.6:
+        #    3. Validate the iss, aud, and iat Claims in the same way they are
+        #       validated in ID Tokens.
+        # Which means the audience should contain Synapse's client_id and the
+        # issuer should be the IdP issuer
+        claims_options = {
+            "iss": {"values": [metadata["issuer"]]},
+            "aud": {"values": [self.client_id]},
+        }
+
+        try:
+            claims = await self._verify_jwt(
+                alg_values=alg_values,
+                token=logout_token,
+                claims_cls=LogoutToken,
+                claims_options=claims_options,
+            )
+        except JoseError:
+            logger.exception("Invalid logout_token")
+            raise SynapseError(400, "Invalid logout_token")
+
+        # As per sec. 2.6:
+        #    4. Verify that the Logout Token contains a sub Claim, a sid Claim,
+        #       or both.
+        #    5. Verify that the Logout Token contains an events Claim whose
+        #       value is a JSON object containing the member name
+        #       http://schemas.openid.net/event/backchannel-logout.
+        #    6. Verify that the Logout Token does not contain a nonce Claim.
+        # This is all verified by the LogoutToken claims class, so at this
+        # point the `sid` claim exists and is a string.
+        sid: str = claims.get("sid")
+
+        # If the `sub` claim was included in the logout token, we check that it
+        # matches the right user. We can have cases where the `sub` claim is not
+        # the ID saved in the database, so we let admins disable this check in config.
+        sub: Optional[str] = claims.get("sub")
+        expected_user_id: Optional[str] = None
+        if sub is not None and not self._config.backchannel_logout_ignore_sub:
+            expected_user_id = await self._store.get_user_by_external_id(
+                self.idp_id, sub
+            )
+
+        # Invalidate any running user-mapping sessions, in-flight login tokens and
+        # active devices
+        await self._sso_handler.revoke_sessions_for_provider_session_id(
+            auth_provider_id=self.idp_id,
+            auth_provider_session_id=sid,
+            expected_user_id=expected_user_id,
+        )
+
+        request.setResponseCode(200)
+        request.setHeader(b"Cache-Control", b"no-cache, no-store")
+        request.setHeader(b"Pragma", b"no-cache")
+        finish_request(request)
+
+
+class LogoutToken(JWTClaims):
+    """
+    Holds and verifies the claims of a logout token, as per
+    https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
+    """
+
+    REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
+
+    def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
+        """Validate everything in claims payload."""
+        super().validate(now, leeway)
+        self.validate_sid()
+        self.validate_events()
+        self.validate_nonce()
+
+    def validate_sid(self) -> None:
+        """Ensure the sid claim is present"""
+        sid = self.get("sid")
+        if not sid:
+            raise MissingClaimError("sid")
+
+        if not isinstance(sid, str):
+            raise InvalidClaimError("sid")
+
+    def validate_nonce(self) -> None:
+        """Ensure the nonce claim is absent"""
+        if "nonce" in self:
+            raise InvalidClaimError("nonce")
+
+    def validate_events(self) -> None:
+        """Ensure the events claim is present and has the right value"""
+        events = self.get("events")
+        if not events:
+            raise MissingClaimError("events")
+
+        if not isinstance(events, dict):
+            raise InvalidClaimError("events")
+
+        if "http://schemas.openid.net/event/backchannel-logout" not in events:
+            raise InvalidClaimError("events")
+
 
 # number of seconds a newly-generated client secret should be valid for
 CLIENT_SECRET_VALIDITY_SECONDS = 3600
@@ -1112,6 +1423,7 @@ class JwtClientSecret:
         logger.info(
             "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
         )
+        jwt = JsonWebToken(header["alg"])
         self._cached_secret = jwt.encode(header, payload, self._key.key)
         self._cached_secret_replacement_time = (
             expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict):
     emails: List[str]
 
-C = TypeVar("C")
-
-
 class OidcMappingProvider(Generic[C]):
     """A mapping provider maps a UserInfo object to user attributes.
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 5943f08e91..749d7e93b0 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -191,6 +191,7 @@ class SsoHandler:
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
         self._error_template = hs.config.sso.sso_error_template
         self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
         self._profile_handler = hs.get_profile_handler()
@@ -1026,6 +1027,76 @@ class SsoHandler:
 
         return True
 
+    async def revoke_sessions_for_provider_session_id(
+        self,
+        auth_provider_id: str,
+        auth_provider_session_id: str,
+        expected_user_id: Optional[str] = None,
+    ) -> None:
+        """Revoke any devices and in-flight logins tied to a provider session.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            auth_provider_session_id: The session ID from the provider to log out
+            expected_user_id: The user we're expecting to log out. If set, it will
+                ignore sessions belonging to other users and log an error.
+        """
+        # Invalidate any running user-mapping sessions
+        to_delete = []
+        for session_id, session in self._username_mapping_sessions.items():
+            if (
+                session.auth_provider_id == auth_provider_id
+                and session.auth_provider_session_id == auth_provider_session_id
+            ):
+                to_delete.append(session_id)
+
+        for session_id in to_delete:
+            logger.info("Revoking mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
+
+        # Invalidate any in-flight login tokens
+        await self._store.invalidate_login_tokens_by_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # Fetch any device(s) in the store associated with the session ID.
+        devices = await self._store.get_devices_by_auth_provider_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # We have no guarantee that all the devices of that session are for the same
+        # `user_id`. Hence, we have to iterate over the list of devices and log them out
+        # one by one.
+        for device in devices:
+            user_id = device["user_id"]
+            device_id = device["device_id"]
+
+            # If the user_id associated with that device/session is not the one we got
+            # out of the `sub` claim, skip that device and log an error.
+            if expected_user_id is not None and user_id != expected_user_id:
+                logger.error(
+                    "Received a logout notification from SSO provider "
+                    f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
+                    f"a session ID ({auth_provider_session_id!r}) which belongs to "
+                    f"{user_id!r}. This may happen when the SSO provider user mapper "
+                    "uses something other than the standard attribute as mapping ID. "
+                    "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
+                    "in the provider config if that is the case."
+ ) + continue + + logger.info( + "Logging out %r (device %r) via SSO (%r) logout notification (session %r).", + user_id, + device_id, + auth_provider_id, + auth_provider_session_id, + ) + await self._device_handler.delete_devices(user_id, [device_id]) + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py index 81fec39659..e4b28ce3df 100644 --- a/synapse/rest/synapse/client/oidc/__init__.py +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING from twisted.web.resource import Resource +from synapse.rest.synapse.client.oidc.backchannel_logout_resource import ( + OIDCBackchannelLogoutResource, +) from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource if TYPE_CHECKING: @@ -29,6 +32,7 @@ class OIDCResource(Resource): def __init__(self, hs: "HomeServer"): Resource.__init__(self) self.putChild(b"callback", OIDCCallbackResource(hs)) + self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs)) __all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py new file mode 100644 index 0000000000..e07e76855a --- /dev/null +++ b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py @@ -0,0 +1,35 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import DirectServeJsonResource +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class OIDCBackchannelLogoutResource(DirectServeJsonResource): + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_POST(self, request: SynapseRequest) -> None: + await self._oidc_handler.handle_backchannel_logout(request) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0255295317..5167089e03 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): self._clock.time_msec(), ) + async def invalidate_login_tokens_by_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> None: + """Invalidate login tokens with the given IdP session ID. 
+ + Args: + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider + """ + await self.db_pool.simple_update( + table="login_tokens", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + updatevalues={"used_ts": self._clock.time_msec()}, + desc="invalidate_login_tokens_by_session_id", + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index ebf653d018..847294dc8e 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple, Union @@ -21,7 +22,7 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType -from synapse.api.errors import Codes +from synapse.api.errors import Codes, SynapseError from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -32,8 +33,8 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC -from tests.rest.client.utils import TEST_OIDC_CONFIG -from tests.server import FakeChannel +from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER +from tests.server import FakeChannel, make_request from tests.unittest import override_config, skip_unless @@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) - def is_access_token_valid(self, access_token: str) -> bool: - """ - Checks whether an access token is valid, returning whether it is or not. - """ - code = self.make_request( - "GET", "/_matrix/client/v3/account/whoami", access_token=access_token - ).code - - # Either 200 or 401 is what we get back; anything else is a bug. - assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} - - return code == HTTPStatus.OK - def test_login_issue_refresh_token(self) -> None: """ A login response should include a refresh_token only if asked. @@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.reactor.advance(59.0) # Both tokens should still be valid. - self.assertTrue(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 61 s (just past 1 minute, the time of expiry) self.reactor.advance(2.0) # Only the non-refreshable token is still valid. 
- self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 599 s (just shy of 10 minutes, the time of expiry) self.reactor.advance(599.0 - 61.0) # It's still the case that only the non-refreshable token is still valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK) # Advance to 601 s (just past 10 minutes, the time of expiry) self.reactor.advance(2.0) # Now neither token is valid. - self.assertFalse(self.is_access_token_valid(refreshable_access_token)) - self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + self.helper.whoami( + refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) + self.helper.whoami( + nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED + ) @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} @@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # and no refresh token self.assertEqual(_table_length("access_tokens"), 0) self.assertEqual(_table_length("refresh_tokens"), 0) + + +def oidc_config( + id: str, with_localpart_template: bool, **kwargs: Any +) -> Dict[str, Any]: + """Sample OIDC provider config used in backchannel logout tests. + + Args: + id: IDP ID for this provider + with_localpart_template: Set to `true` to have a default localpart_template in + the `user_mapping_provider` config and skip the user mapping session + **kwargs: rest of the config + + Returns: + A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of + the HS config + """ + config: Dict[str, Any] = { + "idp_id": id, + "idp_name": id, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + } + + if with_localpart_template: + config["user_mapping_provider"] = { + "config": {"localpart_template": "{{ user.sub }}"} + } + else: + config["user_mapping_provider"] = {"config": {}} + + config.update(kwargs) + + return config + + +@skip_unless(HAS_OIDC, "Requires OIDC") +class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): + servlets = [ + account.register_servlets, + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. 
+ config["public_baseurl"] = "http://synapse.test" + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def submit_logout_token(self, logout_token: str) -> FakeChannel: + return self.make_request( + "POST", + "/_synapse/client/oidc/backchannel_logout", + content=f"logout_token={logout_token}", + content_is_form=True, + ) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_simple_logout(self) -> None: + """ + Receiving a logout token should logout the user + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + self.assertNotEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = fake_oidc_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = fake_oidc_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_login(self) -> None: + """ + It should revoke login tokens when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # expect a confirmation page + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now try to exchange the login token + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + # It should have failed + self.assertEqual(channel.code, 403) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=False, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_logout_during_mapping(self) -> None: + """ + It should stop 
ongoing user mapping session when receiving a logout token + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + # Get an authentication, and logout before submitting the logout token + client_redirect_url = "https://x" + userinfo = {"sub": user} + channel, grant = self.helper.auth_via_oidc( + fake_oidc_server, + userinfo, + client_redirect_url, + with_sid=True, + ) + + # Expect a user mapping page + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) + + # We should have a user_mapping_session cookie + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + user_mapping_session_id = cookies["username_mapping_session"] + + # Getting that session should not raise + session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + self.assertIsNotNone(session) + + # Submit a logout + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + # Now it should raise + with self.assertRaises(SynapseError): + self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=False, + ) + ] + } + ) + def test_disabled(self) -> None: + """ + Receiving a logout token should do nothing if it is disabled in the config + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=True + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + id="oidc", + with_localpart_template=True, + backchannel_logout_enabled=True, + ) + ] + } + ) + def test_no_sid(self) -> None: + """ + Receiving a logout token without `sid` during the login should do nothing + """ + fake_oidc_server = self.helper.fake_oidc_server() + user = "john" + + login_resp, grant = self.helper.login_via_oidc( + fake_oidc_server, user, with_sid=False + ) + access_token: str = login_resp["access_token"] + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + # Logging out shouldn't work + logout_token = fake_oidc_server.generate_logout_token(grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 400) + + # And the token should still be valid + self.helper.whoami(access_token, expect_code=HTTPStatus.OK) + + @override_config( + { + "oidc_providers": [ + oidc_config( + "first", + issuer="https://first-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + oidc_config( + "second", + issuer="https://second-issuer.com/", + with_localpart_template=True, + backchannel_logout_enabled=True, + ), + ] + } + ) + def test_multiple_providers(self) -> None: + """ + It should be able to distinguish login tokens from two different IdPs + """ + first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/") + 
second_server = self.helper.fake_oidc_server( + issuer="https://second-issuer.com/" + ) + user = "john" + + login_resp, first_grant = self.helper.login_via_oidc( + first_server, user, with_sid=True, idp_id="oidc-first" + ) + first_access_token: str = login_resp["access_token"] + self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK) + + login_resp, second_grant = self.helper.login_via_oidc( + second_server, user, with_sid=True, idp_id="oidc-second" + ) + second_access_token: str = login_resp["access_token"] + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # `sid` values in the fake providers are generated by a counter, so the first grant of + # each provider should give the same SID + self.assertEqual(first_grant.sid, second_grant.sid) + self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"]) + + # Logging out of the first session + logout_token = first_server.generate_logout_token(first_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED) + self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK) + + # Logging out of the second session + logout_token = second_server.generate_logout_token(second_grant) + channel = self.submit_logout_token(logout_token) + self.assertEqual(channel.code, 200) + + self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 967d229223..706399fae5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -553,6 +553,34 @@ class RestHelper: return channel.json_body + def whoami( + self, + access_token: str, + expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK, + ) -> JsonDict: + """Perform a 'whoami' request, which can be a quick way to check for access + token validity + + Args: + access_token: The user token to use during the request + expect_code: The return code to expect from attempting the whoami request + """ + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "account/whoami", + access_token=access_token, + ) + + assert channel.code == expect_code, "Expected: %d, got %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: """Create a ``FakeOidcServer``. @@ -572,6 +600,7 @@ class RestHelper: fake_server: FakeOidcServer, remote_user_id: str, with_sid: bool = False, + idp_id: Optional[str] = None, expected_status: int = 200, ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC @@ -588,7 +617,11 @@ class RestHelper: client_redirect_url = "https://x" userinfo = {"sub": remote_user_id} channel, grant = self.auth_via_oidc( - fake_server, userinfo, client_redirect_url, with_sid=with_sid + fake_server, + userinfo, + client_redirect_url, + with_sid=with_sid, + idp_id=idp_id, ) # expect a confirmation page @@ -623,6 +656,7 @@ class RestHelper: client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, with_sid: bool = False, + idp_id: Optional[str] = None, ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. @@ -648,6 +682,7 @@ class RestHelper: ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth.
with_sid: if True, generates a random `sid` (OIDC session ID) + idp_id: if set, explicitly chooses one specific IDP Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -665,7 +700,9 @@ class RestHelper: oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) else: # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + oauth_uri = self.initiate_sso_login( + client_redirect_url, cookies, idp_id=idp_id + ) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -742,7 +779,10 @@ class RestHelper: return channel, grant def initiate_sso_login( - self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + self, + client_redirect_url: Optional[str], + cookies: MutableMapping[str, str], + idp_id: Optional[str] = None, ) -> str: """Make a request to the login-via-sso redirect endpoint, and return the target @@ -753,6 +793,7 @@ class RestHelper: client_redirect_url: the client redirect URL to pass to the login redirect endpoint cookies: any cookies returned will be added to this dict + idp_id: if set, explicitly chooses one specific IDP Returns: the URI that the client gets redirected to (ie, the SSO server) @@ -761,6 +802,12 @@ class RestHelper: if client_redirect_url: params["redirectUrl"] = client_redirect_url + uri = "/_matrix/client/r0/login/sso/redirect" + if idp_id is not None: + uri = f"{uri}/{idp_id}" + + uri = f"{uri}?{urllib.parse.urlencode(params)}" + # hit the redirect url (which should redirect back to the redirect url. This # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. + channel = make_request( self.hs.get_reactor(), self.site, "GET", - "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), + uri, ) assert channel.code == 302 diff --git a/tests/server.py b/tests/server.py index 8b1d186219..b1730fcc8d 100644 --- a/tests/server.py +++ b/tests/server.py @@ -362,6 +362,12 @@ def make_request( # Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END) + # Old versions of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded # bodies if the Content-Length header is missing + req.requestHeaders.addRawHeader( + b"Content-Length", str(len(content)).encode("ascii") + ) + if access_token: req.requestHeaders.addRawHeader( b"Authorization", b"Bearer " + access_token.encode("ascii") diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index de134bbc89..1461d23ee8 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -51,6 +51,8 @@ class FakeOidcServer: get_userinfo_handler: Mock post_token_handler: Mock + sid_counter: int = 0 + def __init__(self, clock: Clock, issuer: str): from authlib.jose import ECKey, KeySet @@ -146,7 +148,7 @@ class FakeOidcServer: return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: - now = self._clock.time() + now = int(self._clock.time()) id_token = { **grant.userinfo, "iss": self.issuer, @@ -166,6 +168,26 @@ class FakeOidcServer: return self._sign(id_token) + def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str: + now = int(self._clock.time()) + logout_token = { + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "jti": random_string(10), + "events": { + "http://schemas.openid.net/event/backchannel-logout": {}, + }, + } + + if grant.sid is not None: + logout_token["sid"] = grant.sid + + if "sub" in grant.userinfo: + logout_token["sub"] = grant.userinfo["sub"] + + return self._sign(logout_token) + def id_token_override(self, overrides: dict): """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -183,7 +205,8 @@ class FakeOidcServer: code = random_string(10) sid = None if with_sid: - sid = random_string(10) + sid = str(self.sid_counter) + self.sid_counter += 1 grant = FakeAuthorizationGrant( userinfo=userinfo, -- cgit 1.4.1 From dbfc9b803ee32f7b31c2b5ccbc53a1bfcaa95983 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 20:31:43 +0000 Subject: Fix dehydrated device REST checks (#14336) --- changelog.d/14336.bugfix | 1 + synapse/rest/client/devices.py | 5 ++--- tests/rest/client/test_devices.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14336.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14336.bugfix b/changelog.d/14336.bugfix new file mode 100644 index 0000000000..d44ff1bbc7 --- /dev/null +++ b/changelog.d/14336.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70 where clients were unable to PUT new [dehydrated devices](https://github.com/matrix-org/matrix-spec-proposals/pull/2697).
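The regression itself is a request-validation detail worth spelling out: `PutBody` is a pydantic model, so the stray required `device_id` field rejected otherwise valid bodies, and the validated `device_data` is a model instance rather than the plain dict the storage layer expects, hence the new `.dict()` call below. A minimal sketch of both behaviours using plain pydantic (the `DehydratedDeviceDataModel` fields here are assumed for illustration, mirroring the request body in the test; this is not the shipped Synapse code):

```python
from typing import Optional

from pydantic import BaseModel, StrictStr, ValidationError


class DehydratedDeviceDataModel(BaseModel):
    # Assumed shape, mirroring the request body used in the test below.
    algorithm: StrictStr
    account: StrictStr


class PutBody(BaseModel):
    device_data: DehydratedDeviceDataModel
    initial_device_display_name: Optional[StrictStr]


class BuggyPutBody(PutBody):
    # The pre-fix model also required a device_id, which clients do not send.
    device_id: StrictStr


body = {
    "device_data": {
        "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
        "account": "dehydrated_device",
    }
}

try:
    BuggyPutBody.parse_obj(body)
except ValidationError:
    pass  # rejected: device_id is missing -- the failure clients were hitting

submission = PutBody.parse_obj(body)  # the fixed model accepts the body
# The nested model must be unwrapped back into a plain dict before persisting:
assert submission.device_data.dict() == body["device_data"]
```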
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 90828c95c4..8f3cbd4ea2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -231,7 +231,7 @@ class DehydratedDeviceServlet(RestServlet): } } - PUT /org.matrix.msc2697/dehydrated_device + PUT /org.matrix.msc2697.v2/dehydrated_device Content-Type: application/json { @@ -271,7 +271,6 @@ class DehydratedDeviceServlet(RestServlet): raise errors.NotFoundError("No dehydrated device available") class PutBody(RequestBodyModel): - device_id: StrictStr device_data: DehydratedDeviceDataModel initial_device_display_name: Optional[StrictStr] @@ -281,7 +280,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission.device_data, + submission.device_data.dict(), submission.initial_device_display_name, ) return 200, {"device_id": device_id} diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index aa98222434..d80eea17d3 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase): self.reactor.advance(43200) self.get_success(self.handler.get_device(user_id, "abc")) self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) + + +class DehydratedDeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + + def test_PUT(self) -> None: + """Sanity-check that we can PUT a dehydrated device. + + Detects https://github.com/matrix-org/synapse/issues/14334. + """ + alice = self.register_user("alice", "correcthorse") + token = self.login(alice, "correcthorse") + + # Have alice update their device list + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device", + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device", + } + }, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + device_id = channel.json_body.get("device_id") + self.assertIsInstance(device_id, str) -- cgit 1.4.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'tests/rest') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. 
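In client terms, the new unstable-prefixed knob rides along in the body of an ordinary redaction request: a list of relation types under `org.matrix.msc3912.with_relations`, asking the server to also redact events that relate to the target via one of those types. A hedged sketch of such a request, where the server URL, token, room and event IDs are all placeholders and `requests` merely stands in for any HTTP client:

```python
import requests

homeserver = "https://matrix.example.org"  # placeholder
access_token = "syt_placeholder_token"  # placeholder
room_id = "!room:example.org"  # placeholder
event_id = "$thread_root"  # placeholder
txn_id = "redact1"  # transaction ID, making the PUT idempotent

resp = requests.put(
    f"{homeserver}/_matrix/client/v3/rooms/{room_id}/redact/{event_id}/{txn_id}",
    headers={"Authorization": f"Bearer {access_token}"},
    json={
        # MSC3912: also redact events relating to the target with these
        # relation types (the values of RelationTypes.THREAD / REPLACE).
        "org.matrix.msc3912.with_relations": ["m.thread", "m.replace"],
    },
)
resp.raise_for_status()
print(resp.json()["event_id"])  # the redaction event's ID
```

Clients can feature-detect this via the `org.matrix.msc3912` entry in `unstable_features` on `/_matrix/client/versions`, added below.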
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. 
async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event ID whose related events to look up and redact. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for.
+ + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to.
+ event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. 
+ event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the thread by the mod user is not + # redacted.
+ event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. 
+ """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.4.1 From a4b1f6456276e62b3f4d6b060c289b6413b8a5c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 4 Nov 2022 18:43:51 +0200 Subject: Fix /refresh endpoint version (#14364) --- changelog.d/14364.bugfix | 1 + synapse/rest/client/login.py | 2 +- tests/rest/client/test_auth.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14364.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14364.bugfix b/changelog.d/14364.bugfix new file mode 100644 index 0000000000..514bf859bb --- /dev/null +++ b/changelog.d/14364.bugfix @@ -0,0 +1 @@ +Fix refresh token endpoint to be under /r0 and /v3 instead of /v1. Contributed by Tulir @ Beeper. diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 7774f1967d..05706b598c 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -536,7 +536,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) + PATTERNS = client_patterns("/refresh$") def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 847294dc8e..208ec44829 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -635,7 +635,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): """ return self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": refresh_token}, ) @@ -724,7 +724,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -765,7 +765,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) @@ -1002,7 +1002,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1012,7 +1012,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1022,7 +1022,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - 
"/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1056,7 +1056,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -1068,7 +1068,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/v1/refresh", + "/_matrix/client/v3/refresh", {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( -- cgit 1.4.1 From 7894251bcea7714b47e3849e509ea717bb18e9f5 Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 7 Nov 2022 13:38:50 -0800 Subject: Correctly create power level event during initial room creation (#14361) --- changelog.d/14361.bugfix | 1 + synapse/handlers/room.py | 25 +++++++++++++++++++++++-- tests/rest/client/test_rooms.py | 4 ++-- 3 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 changelog.d/14361.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14361.bugfix b/changelog.d/14361.bugfix new file mode 100644 index 0000000000..33ba1d92af --- /dev/null +++ b/changelog.d/14361.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.71.0rc1 where the power level event was incorrectly created during initial room creation. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f10cfca073..66a50bca6e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1080,6 +1080,19 @@ class RoomCreationHandler: for_batch: bool, **kwargs: Any, ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + """ + Creates an event and associated event context. + Args: + etype: the type of event to be created + content: content of the event + for_batch: whether the event is being created for batch persisting. If + bool for_batch is true, this will create an event using the prev_event_ids, + and will create an event context for the event using the parameters state_map + and current_state_group, thus these parameters must be provided in this + case if for_batch is True. The subsequently created event and context + are suitable for being batched up and bulk persisted to the database + with other similarly created events. + """ nonlocal depth nonlocal prev_event @@ -1139,13 +1152,21 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + # we need the state group of the membership event as it is the current state group + event_to_state = ( + await self._storage_controllers.state.get_state_group_for_events( + [member_event_id] + ) + ) + current_state_group = event_to_state[member_event_id] + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. 
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: power_event, power_context = await create_event( - EventTypes.PowerLevels, pl_content, False + EventTypes.PowerLevels, pl_content, True ) current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) @@ -1194,7 +1215,7 @@ class RoomCreationHandler: pl_event, pl_context = await create_event( EventTypes.PowerLevels, power_level_content, - False, + True, ) current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 1084d4ad9d..e919e089cb 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -715,7 +715,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -728,7 +728,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(37, channel.resource_usage.db_txn_count) + self.assertEqual(36, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.4.1 From a3623af74e0af0d2f6cbd37b47dc54a1acd314d5 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Date: Fri, 11 Nov 2022 19:38:17 +0400 Subject: Add an Admin API endpoint for looking up users based on 3PID (#14405) --- changelog.d/14405.feature | 1 + docs/admin_api/user_admin_api.md | 39 ++++++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 25 +++++++++ tests/rest/admin/test_user.py | 107 ++++++++++++++++++++++++++++++++++----- 5 files changed, 161 insertions(+), 13 deletions(-) create mode 100644 changelog.d/14405.feature (limited to 'tests/rest') diff --git a/changelog.d/14405.feature b/changelog.d/14405.feature new file mode 100644 index 0000000000..d3ba89b597 --- /dev/null +++ b/changelog.d/14405.feature @@ -0,0 +1 @@ +Add an [Admin API](https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html) endpoint for user lookup based on third-party ID (3PID). Contributed by @ashfame. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c95d6c9b05..880bef4194 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -1197,3 +1197,42 @@ Returns a `404` HTTP status code if no user was found, with a response body like ``` _Added in Synapse 1.68.0._ + + +### Find a user based on their Third Party ID (ThreePID or 3PID) + +The API is: + +``` +GET /_synapse/admin/v1/threepid/$medium/users/$address +``` + +When a user matched the given address for the given medium, an HTTP code `200` with a response body like the following is returned: + +```json +{ + "user_id": "@hello:example.org" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `medium` - Kind of third-party ID, either `email` or `msisdn`. +- `address` - Value of the third-party ID. 
+ +The `address` may have characters that are not URL-safe, so it is advised to URL-encode those parameters. + +**Errors** + +Returns a `404` HTTP status code if no user was found, with a response body like this: + +```json +{ + "errcode":"M_NOT_FOUND", + "error":"User not found" +} +``` + +_Added in Synapse 1.72.0._ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 885669f9c7..c62ea22116 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -81,6 +81,7 @@ from synapse.rest.admin.users import ( ShadowBanRestServlet, UserAdminServlet, UserByExternalId, + UserByThreePid, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -277,6 +278,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMessagesRestServlet(hs).register(http_server) RoomTimestampToEventRestServlet(hs).register(http_server) UserByExternalId(hs).register(http_server) + UserByThreePid(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 15ac2059aa..1951b8a9f2 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1224,3 +1224,28 @@ class UserByExternalId(RestServlet): raise NotFoundError("User not found") return HTTPStatus.OK, {"user_id": user_id} + + +class UserByThreePid(RestServlet): + """Find a user based on 3PID of a particular medium""" + + PATTERNS = admin_patterns("/threepid/(?P<medium>[^/]*)/users/(?P<address>
[^/]*)") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET( + self, + request: SynapseRequest, + medium: str, + address: str, + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + user_id = await self._store.get_user_id_by_threepid(medium, address) + + if user_id is None: + raise NotFoundError("User not found") + + return HTTPStatus.OK, {"user_id": user_id} diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 63410ffdf1..e8c9457794 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -41,14 +41,12 @@ from tests.unittest import override_config class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, profile.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = "/_synapse/admin/v1/register" self.registration_handler = Mock() @@ -446,7 +444,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): class UsersListTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1108,7 +1105,6 @@ class UserDevicesTestCase(unittest.HomeserverTestCase): class DeactivateAccountTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -1382,7 +1378,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): class UserRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2803,7 +2798,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): class UserMembershipRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -2960,7 +2954,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): class PushersRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3089,7 +3082,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): class UserMediaRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3881,7 +3873,6 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ], ) class WhoisRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -3961,7 +3952,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): class ShadowBanRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4042,7 +4032,6 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): class RateLimitTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4268,7 +4257,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): class AccountDataTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4358,7 +4346,6 @@ class AccountDataTestCase(unittest.HomeserverTestCase): class UsersByExternalIdTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -4442,3 +4429,97 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase): {"user_id": self.other_user}, channel.json_body, ) + + +class 
UsersByThreePidTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.get_success( + self.store.user_add_threepid( + self.other_user, "email", "user@email.com", 1, 1 + ) + ) + self.get_success( + self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1) + ) + + def test_no_auth(self) -> None: + """Try to look up a user without authentication.""" + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + ) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_medium_does_not_exist(self) -> None: + """Tests that both a lookup for a medium that does not exist and a user that + doesn't exist with that third party ID returns a 404""" + # test for unknown medium + url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + # test for unknown user with a known medium + url = "/_synapse/admin/v1/threepid/email/users/unknown" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_success(self) -> None: + """Tests a successful medium + address lookup""" + # test for email medium with encoded value of user@email.com + url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) + + # test for msidn medium with encoded value of +1-12345678 + url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + {"user_id": self.other_user}, + channel.json_body, + ) -- cgit 1.4.1 From 1799a54a545618782840a60950ef4b64da9ee24d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 07:26:11 -0500 Subject: Batch fetch bundled annotations (#14491) Avoid an n+1 query problem and fetch the bundled aggregations for m.annotation relations in a single query instead of a query per event. This applies similar logic as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660) and threads in b65acead428653b988351ae8d7b22127a22039cd (#11752).
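The mechanical shape of that rewrite, reduced to a self-contained sketch (a simplified `event_relations` table in SQLite, not Synapse's actual storage layer): rather than one aggregation query per event, issue a single query over all event IDs, `GROUP BY` the relates-to ID, and bucket the rows per event in Python.

```python
import sqlite3
from collections import defaultdict
from typing import Collection, Dict, List, Tuple

db = sqlite3.connect(":memory:")
db.execute(
    "CREATE TABLE event_relations ("
    " event_id TEXT, relates_to_id TEXT, relation_type TEXT,"
    " aggregation_key TEXT, sender TEXT)"
)
db.executemany(
    "INSERT INTO event_relations VALUES (?, ?, ?, ?, ?)",
    [
        ("$e1", "$root1", "m.annotation", "👍", "@a:x"),
        ("$e2", "$root1", "m.annotation", "👍", "@b:x"),
        ("$e3", "$root2", "m.annotation", "🎉", "@a:x"),
    ],
)


def annotations_batched(
    event_ids: Collection[str],
) -> Dict[str, List[Tuple[str, str, int]]]:
    # One round-trip for the whole batch, instead of len(event_ids) queries.
    placeholders = ", ".join("?" for _ in event_ids)
    rows = db.execute(
        "SELECT relates_to_id, relation_type, aggregation_key,"
        " COUNT(DISTINCT sender)"
        f" FROM event_relations WHERE relates_to_id IN ({placeholders})"
        " AND relation_type = 'm.annotation'"
        " GROUP BY relates_to_id, relation_type, aggregation_key"
        " ORDER BY relates_to_id, COUNT(*) DESC",
        tuple(event_ids),
    ).fetchall()
    results: Dict[str, List[Tuple[str, str, int]]] = defaultdict(list)
    for event_id, rel_type, key, count in rows:
        results[event_id].append((rel_type, key, count))
    return results


print(dict(annotations_batched(["$root1", "$root2"])))
# {'$root1': [('m.annotation', '👍', 2)], '$root2': [('m.annotation', '🎉', 1)]}
```

One consequence of batching is that the per-event cap on returned groups (5, in the diff below) has to be enforced in Python rather than with a SQL `LIMIT`, since a single `LIMIT` would apply to the whole batch; that is what the new `_get_aggregation_groups_for_events_txn` does.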
--- changelog.d/14491.feature | 1 + synapse/handlers/relations.py | 197 ++++++++++++++++------------ synapse/storage/databases/main/relations.py | 139 ++++++++++++-------- synapse/util/caches/descriptors.py | 2 +- tests/rest/client/test_relations.py | 4 +- 5 files changed, 202 insertions(+), 141 deletions(-) create mode 100644 changelog.d/14491.feature (limited to 'tests/rest') diff --git a/changelog.d/14491.feature b/changelog.d/14491.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14491.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e71dda970..ca94239f61 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,7 +13,16 @@ # limitations under the License. import enum import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, +) import attr @@ -259,48 +268,64 @@ class RelationsHandler: e.msg, ) - async def get_annotations_for_event( - self, - event_id: str, - room_id: str, - limit: int = 5, - ignored_users: FrozenSet[str] = frozenset(), - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + async def get_annotations_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[JsonDict]]: + """Get a list of annotations to the given events, grouped by event type and aggregation key, sorted by count. - This is used e.g. to get the what and how many reactions have happend + This is used e.g. to get the what and how many reactions have happened on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. ignored_users: The users ignored by the requesting user. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_event( - event_id, room_id, limit + full_results = await self._main_store.get_aggregation_groups_for_events( + event_ids ) + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in full_results.items() + if results + } + # Then subtract off the results for any ignored users. ignored_results = await self._main_store.get_aggregation_groups_for_users( - event_id, room_id, limit, ignored_users + [event_id for event_id, results in full_results.items() if results], + ignored_users, ) - filtered_results = [] - for result in full_results: - key = (result["type"], result["key"]) - if key in ignored_results: - result = result.copy() - result["count"] -= ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.append(result) + filtered_results = {} + for event_id, results in full_results.items(): + # If no annotations, skip. + if not results: + continue + + # If there are not ignored results for this event, copy verbatim. 
+ if event_id not in ignored_results: + filtered_results[event_id] = results + continue + + # Otherwise, subtract out the ignored results. + event_ignored_results = ignored_results[event_id] + for result in results: + key = (result["type"], result["key"]) + if key in event_ignored_results: + # Ensure to not modify the cache. + result = result.copy() + result["count"] -= event_ignored_results[key] + if result["count"] <= 0: + continue + filtered_results.setdefault(event_id, []).append(result) return filtered_results @@ -366,59 +391,62 @@ class RelationsHandler: results = {} for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event = summary - - # Subtract off the count of any ignored users. - for ignored_user in ignored_users: - thread_count -= ignored_results.get((event_id, ignored_user), 0) - - # This is gnarly, but if the latest event is from an ignored user, - # attempt to find one that isn't from an ignored user. - if latest_thread_event.sender in ignored_users: - room_id = latest_thread_event.room_id - - # If the root event is not found, something went wrong, do - # not include a summary of the thread. - event = await self._event_handler.get_event(user, room_id, event_id) - if event is None: - continue + # If no thread, skip. + if not summary: + continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, - ) + thread_count, latest_thread_event = summary - # If all found events are from ignored users, do not include - # a summary of the thread. - if not potential_events: - continue + # Subtract off the count of any ignored users. + for ignored_user in ignored_users: + thread_count -= ignored_results.get((event_id, ignored_user), 0) - # The *last* event returned is the one that is cared about. - event = await self._event_handler.get_event( - user, room_id, potential_events[-1].event_id - ) - # It is unexpected that the event will not exist. - if event is None: - logger.warning( - "Unable to fetch latest event in a thread with event ID: %s", - potential_events[-1].event_id, - ) - continue - latest_thread_event = event - - results[event_id] = _ThreadAggregation( - latest_event=latest_thread_event, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=events_by_id[event_id].sender == user_id - or participated[event_id], + # This is gnarly, but if the latest event is from an ignored user, + # attempt to find one that isn't from an ignored user. + if latest_thread_event.sender in ignored_users: + room_id = latest_thread_event.room_id + + # If the root event is not found, something went wrong, do + # not include a summary of the thread. + event = await self._event_handler.get_event(user, room_id, event_id) + if event is None: + continue + + potential_events, _ = await self.get_relations_for_event( + event_id, + event, + room_id, + RelationTypes.THREAD, + ignored_users, ) + # If all found events are from ignored users, do not include + # a summary of the thread. + if not potential_events: + continue + + # The *last* event returned is the one that is cared about. + event = await self._event_handler.get_event( + user, room_id, potential_events[-1].event_id + ) + # It is unexpected that the event will not exist. 
+ if event is None: + logger.warning( + "Unable to fetch latest event in a thread with event ID: %s", + potential_events[-1].event_id, + ) + continue + latest_thread_event = event + + results[event_id] = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], + ) + return results @trace @@ -496,17 +524,18 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any annotations (ie, reactions) to bundle with this event. - annotations = await self.get_annotations_for_event( - event.event_id, event.room_id, ignored_users=ignored_users - ) + # Fetch any annotations (ie, reactions) to bundle with this event. + annotations_by_event_id = await self.get_annotations_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, annotations in annotations_by_event_id.items(): if annotations: - results.setdefault( - event.event_id, BundledAggregations() - ).annotations = {"chunk": annotations} + results.setdefault(event_id, BundledAggregations()).annotations = { + "chunk": annotations + } + # Fetch other relations per event. + for event in events_by_id.values(): # Fetch any references to bundle with this event. references, next_token = await self.get_relations_for_event( event.event_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ca431002c8..f96a16956a 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -20,6 +20,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore): ) return result is not None - @cached(tree=True) - async def get_aggregation_groups_for_event( - self, event_id: str, room_id: str, limit: int = 5 - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + @cached() + async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" + ) + async def get_aggregation_groups_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[JsonDict]]]: + """Get a list of annotations on the given events, grouped by event type and aggregation key, sorted by count. This is used e.g. to get the what and how many reactions have happend on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ + # The number of entries to return per event ID. 
+ limit = 5 - args = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - limit, - ] + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.ANNOTATION) - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + sql = f""" + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE + {clause} + AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ - def _get_aggregation_groups_for_event_txn( + def _get_aggregation_groups_for_events_txn( txn: LoggingTransaction, - ) -> List[JsonDict]: + ) -> Mapping[str, List[JsonDict]]: txn.execute(sql, args) - return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] + result: Dict[str, List[JsonDict]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + event_results = result.setdefault(event_id, []) + + # Limit the number of results per event ID. + if len(event_results) == limit: + continue + + event_results.append({"type": type, "key": key, "count": count}) + + return result return await self.db_pool.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn ) async def get_aggregation_groups_for_users( - self, - event_id: str, - room_id: str, - limit: int, - users: FrozenSet[str] = frozenset(), - ) -> Dict[Tuple[str, str], int]: + self, event_ids: Collection[str], users: FrozenSet[str] + ) -> Dict[str, Dict[Tuple[str, str], int]]: """Fetch the partial aggregations for an event for specific users. This is used, in conjunction with get_aggregation_groups_for_event, to remove information from the results for ignored users. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. users: The users to fetch information for. Returns: - A map of (event type, aggregation key) to a count of users. + A map of event ID to a map of (event type, aggregation key) to a + count of users. """ if not users: return {} - args: List[Union[str, int]] = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - ] + events_sql, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "sender", users + self.database_engine, "annotation.sender", users ) args.extend(users_args) + args.append(RelationTypes.ANNOTATION) sql = f""" - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? 
+            SELECT
+                relates_to_id,
+                annotation.type,
+                aggregation_key,
+                COUNT(DISTINCT annotation.sender)
+            FROM events AS annotation
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = annotation.room_id
+            WHERE {events_sql} AND {users_sql} AND relation_type = ?
+            GROUP BY relates_to_id, annotation.type, aggregation_key
+            ORDER BY relates_to_id, COUNT(*) DESC
         """
 
         def _get_aggregation_groups_for_users_txn(
             txn: LoggingTransaction,
-        ) -> Dict[Tuple[str, str], int]:
-            txn.execute(sql, args + [limit])
+        ) -> Dict[str, Dict[Tuple[str, str], int]]:
+            txn.execute(sql, args)
 
-            return {(row[0], row[1]): row[2] for row in txn}
+            result: Dict[str, Dict[Tuple[str, str], int]] = {}
+            for event_id, type, key, count in cast(
+                List[Tuple[str, str, str, int]], txn
+            ):
+                result.setdefault(event_id, {})[(type, key)] = count
+
+            return result
 
         return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 75428d19ba..72227359b9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -503,7 +503,7 @@ def cachedList(
     is specified as a list that is iterated through to lookup keys in the
     original cache. A new tuple consisting of the (deduplicated) keys that weren't in
     the cache gets passed to the original function, which is expected to results
-    in a map of key to value for each passed value. THe new results are stored in the
+    in a map of key to value for each passed value. The new results are stored in the
     original cache. Note that any missing values are cached as None.
 
     Args:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index e3d801f7a8..2d2b683548 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # The "user" sent the root event and is making queries for the bundled
         # aggregations: they have participated.
-        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
         # The "user2" sent replies in the thread and is making queries for the
         # bundled aggregations: they have participated.
         #
@@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
             bundled_aggregations["latest_event"].get("unsigned"),
         )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
 
     def test_nested_thread(self) -> None:
         """
-- cgit 1.4.1


From 6d7523ef1484ec56f4a6dffdd2ea3d8736b4cc98 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Tue, 22 Nov 2022 09:41:09 -0500
Subject: Batch fetch bundled references (#14508)

Avoid an n+1 query problem and fetch the bundled aggregations for
m.reference relations in a single query instead of a query per event.

This applies similar logic as was previously done for edits in
8b309adb436c162510ed1402f33b8741d71fc058 (#11660), threads in
b65acead428653b988351ae8d7b22127a22039cd (#11752), and annotations in
1799a54a545618782840a60950ef4b64da9ee24d (#14491).
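The pattern in this commit and its predecessors is the same each time: replace a per-event `@cached` lookup with a `@cachedList` companion that resolves a whole batch of event IDs in one SQL round trip. As a rough, self-contained sketch of the batching idea only: the table layout and names below are simplified stand-ins, not Synapse's actual schema, and the real code goes through the `@cachedList` descriptor rather than raw SQL:

import sqlite3
from collections import defaultdict
from typing import Dict, List, Tuple

def get_references_for_events_batched(
    conn: sqlite3.Connection, event_ids: List[str]
) -> Dict[str, List[Tuple[str, str]]]:
    """Resolve m.reference relations for many parent events in one query.

    Returns a map of parent event ID -> list of (child event ID, sender).
    """
    # One IN (...) clause instead of len(event_ids) separate queries.
    placeholders = ", ".join("?" for _ in event_ids)
    sql = f"""
        SELECT relates_to_id, event_id, sender
        FROM event_relations
        WHERE relates_to_id IN ({placeholders}) AND relation_type = 'm.reference'
        ORDER BY relates_to_id
    """
    results: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for relates_to_id, child_id, sender in conn.execute(sql, event_ids):
        results[relates_to_id].append((child_id, sender))
    return dict(results)

# Toy data to show the shape of the result.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_relations "
    "(event_id TEXT, sender TEXT, relates_to_id TEXT, relation_type TEXT)"
)
conn.executemany(
    "INSERT INTO event_relations VALUES (?, ?, ?, ?)",
    [
        ("$child1", "@alice:test", "$parent1", "m.reference"),
        ("$child2", "@bob:test", "$parent1", "m.reference"),
        ("$child3", "@alice:test", "$parent2", "m.reference"),
    ],
)
print(get_references_for_events_batched(conn, ["$parent1", "$parent2"]))
# roughly: {'$parent1': [('$child1', '@alice:test'), ('$child2', '@bob:test')],
#           '$parent2': [('$child3', '@alice:test')]}

Grouping rows by `relates_to_id` in Python, as above, is also how the per-event cap is applied in `get_aggregation_groups_for_events` earlier in the series (the `limit = 5` handling), since a single grouped query cannot express a per-group LIMIT portably.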
---
 changelog.d/14508.feature                   |   1 +
 synapse/handlers/relations.py               | 128 +++++++++++++---------------
 synapse/storage/databases/main/cache.py     |   1 +
 synapse/storage/databases/main/events.py    |   4 +
 synapse/storage/databases/main/relations.py |  74 ++++++++++++++--
 tests/rest/client/test_relations.py         |   4 +-
 6 files changed, 133 insertions(+), 79 deletions(-)
 create mode 100644 changelog.d/14508.feature

(limited to 'tests/rest')

diff --git a/changelog.d/14508.feature b/changelog.d/14508.feature
new file mode 100644
index 0000000000..4fca7282f7
--- /dev/null
+++ b/changelog.d/14508.feature
@@ -0,0 +1 @@
+Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index ca94239f61..8414be5879 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -13,16 +13,7 @@
 # limitations under the License.
 import enum
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Collection,
-    Dict,
-    FrozenSet,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
 
 import attr
 
@@ -32,7 +23,7 @@ from synapse.events import EventBase, relation_from_event
 from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -181,40 +172,6 @@ class RelationsHandler:
 
         return return_value
 
-    async def get_relations_for_event(
-        self,
-        event_id: str,
-        event: EventBase,
-        room_id: str,
-        relation_type: str,
-        ignored_users: FrozenSet[str] = frozenset(),
-    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
-        """Get a list of events which relate to an event, ordered by topological ordering.
-
-        Args:
-            event_id: Fetch events that relate to this event ID.
-            event: The matching EventBase to event_id.
-            room_id: The room the event belongs to.
-            relation_type: The type of relation.
-            ignored_users: The users ignored by the requesting user.
-
-        Returns:
-            List of event IDs that match relations requested. The rows are of
-            the form `{"event_id": "..."}`.
-        """
-
-        # Call the underlying storage method, which is cached.
-        related_events, next_token = await self._main_store.get_relations_for_event(
-            event_id, event, room_id, relation_type, direction="f"
-        )
-
-        # Filter out ignored users and convert to the expected format.
-        related_events = [
-            event for event in related_events if event.sender not in ignored_users
-        ]
-
-        return related_events, next_token
-
     async def redact_events_related_to(
         self,
         requester: Requester,
@@ -329,6 +286,46 @@ class RelationsHandler:
 
         return filtered_results
 
+    async def get_references_for_events(
+        self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+    ) -> Dict[str, List[_RelatedEvent]]:
+        """Get a list of references to the given events.
+
+        Args:
+            event_ids: Fetch events that relate to these event IDs.
+            ignored_users: The users ignored by the requesting user.
+
+        Returns:
+            A map of event IDs to a list of related events.
+        """
+
+        related_events = await self._main_store.get_references_for_events(event_ids)
+
+        # Avoid additional logic if there are no ignored users.
+ if not ignored_users: + return { + event_id: results + for event_id, results in related_events.items() + if results + } + + # Filter out ignored users. + results = {} + for event_id, events in related_events.items(): + # If no references, skip. + if not events: + continue + + # Filter ignored users out. + events = [event for event in events if event.sender not in ignored_users] + # If there are no events left, skip this event. + if not events: + continue + + results[event_id] = events + + return results + async def _get_threads_for_events( self, events_by_id: Dict[str, EventBase], @@ -412,14 +409,18 @@ class RelationsHandler: if event is None: continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, + # Attempt to find another event to use as the latest event. + potential_events, _ = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.THREAD, direction="f" ) + # Filter out ignored users. + potential_events = [ + event + for event in potential_events + if event.sender not in ignored_users + ] + # If all found events are from ignored users, do not include # a summary of the thread. if not potential_events: @@ -534,27 +535,16 @@ class RelationsHandler: "chunk": annotations } - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any references to bundle with this event. - references, next_token = await self.get_relations_for_event( - event.event_id, - event, - event.room_id, - RelationTypes.REFERENCE, - ignored_users=ignored_users, - ) + # Fetch any references to bundle with this event. + references_by_event_id = await self.get_references_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, references in references_by_event_id.items(): if references: - aggregations = results.setdefault(event.event_id, BundledAggregations()) - aggregations.references = { + results.setdefault(event_id, BundledAggregations()).references = { "chunk": [{"event_id": ev.event_id} for ev in references] } - if next_token: - aggregations.references["next_batch"] = await next_token.to_string( - self._main_store - ) - # Fetch any edits (but not for redacted events). # # Note that there is no use in limiting edits by ignored users since the @@ -600,7 +590,7 @@ class RelationsHandler: room_id, requester, allow_departed_users=True ) - # Note that ignored users are not passed into get_relations_for_event + # Note that ignored users are not passed into get_threads # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). 
thread_roots, next_batch = await self._main_store.get_threads( diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ddb7397714..a58668a380 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache( "get_aggregation_groups_for_event", (relates_to,) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d68f127f9b..0f097a2927 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2049,6 +2049,10 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REFERENCE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_references_for_event, (redacted_relates_to,) + ) if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index f96a16956a..aea96e9d24 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -82,8 +82,6 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str - topological_ordering: Optional[int] - stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore): txn.execute(sql, where_args + [limit + 1]) events = [] - for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: + topo_orderings: List[int] = [] + stream_orderings: List[int] = [] + for event_id, relation_type, sender, topo_ordering, stream_ordering in cast( + List[Tuple[str, str, str, int, int]], txn + ): # Do not include edits for redacted events as they leak event # content. if not is_redacted or relation_type != RelationTypes.REPLACE: - events.append( - _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) - ) + events.append(_RelatedEvent(event_id, sender)) + topo_orderings.append(topo_ordering) + stream_orderings.append(stream_ordering) # If there are more events, generate the next pagination key from the # last event returned. @@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore): # Instead of using the last row (which tells us there is more # data), use the last row to be returned. events = events[:limit] + topo_orderings = topo_orderings[:limit] + stream_orderings = stream_orderings[:limit] - topo = events[-1].topological_ordering - token = events[-1].stream_ordering + topo = topo_orderings[-1] + token = stream_orderings[-1] if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. 
@@ -530,6 +534,60 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) + @cached() + async def get_references_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") + async def get_references_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to these event IDs. + + Returns: + A map of event IDs to a list of related event IDs (and their senders). + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REFERENCE) + + sql = f""" + SELECT relates_to_id, ref.event_id, ref.sender + FROM events AS ref + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = ref.room_id + WHERE + {clause} + AND relation_type = ? + ORDER BY ref.topological_ordering, ref.stream_ordering + """ + + def _get_references_for_events_txn( + txn: LoggingTransaction, + ) -> Mapping[str, List[_RelatedEvent]]: + txn.execute(sql, args) + + result: Dict[str, List[_RelatedEvent]] = {} + for relates_to_id, event_id, sender in cast( + List[Tuple[str, str, str]], txn + ): + result.setdefault(relates_to_id, []).append( + _RelatedEvent(event_id, sender) + ) + + return result + + return await self.db_pool.runInteraction( + "_get_references_for_events_txn", _get_references_for_events_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 2d2b683548..b86f341ff5 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) def test_nested_thread(self) -> None: """ -- cgit 1.4.1 From 6d47b7e32589e816eb766446cc1ff19ea73fc7c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 14:08:04 -0500 Subject: Add a type hint for `get_device_handler()` and fix incorrect types. (#14055) This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places. 
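The fix hinges on one typing pattern, repeated at each call site below: `get_device_handler()` is declared to return the worker-safe base class, and any code that needs main-process-only functionality narrows the type with an `isinstance` assertion. A minimal, self-contained sketch of that pattern (the class bodies and method signatures here are simplified stand-ins, not the real handler API):

from typing import List, Optional

class DeviceWorkerHandler:
    """Read-side device operations, safe to run on any worker."""

    def get_devices_by_user(self, user_id: str) -> List[dict]:
        return []

class DeviceHandler(DeviceWorkerHandler):
    """Adds write operations that must only run on the main process."""

    def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
        print(f"deleting {device_ids} for {user_id}")

def get_device_handler(worker_app: Optional[str]) -> DeviceWorkerHandler:
    # The declared return type is the worker base class, mirroring
    # HomeServer.get_device_handler() in the diff below.
    return DeviceWorkerHandler() if worker_app else DeviceHandler()

handler = get_device_handler(worker_app=None)

# Callers that need main-process-only behaviour narrow the type first:
# mypy accepts the call after the isinstance check, and a misconfigured
# worker fails loudly at startup instead of misbehaving at runtime.
assert isinstance(handler, DeviceHandler)
handler.delete_devices("@alice:example.org", ["ABCDEFGH"])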
--- changelog.d/14055.misc | 1 + synapse/handlers/deactivate_account.py | 4 +++ synapse/handlers/device.py | 65 ++++++++++++++++++++++++++-------- synapse/handlers/e2e_keys.py | 61 ++++++++++++++++--------------- synapse/handlers/register.py | 4 +++ synapse/handlers/set_password.py | 6 +++- synapse/handlers/sso.py | 9 +++++ synapse/module_api/__init__.py | 10 +++++- synapse/replication/http/devices.py | 11 ++++-- synapse/rest/admin/__init__.py | 26 ++++++++------ synapse/rest/admin/devices.py | 13 +++++-- synapse/rest/client/devices.py | 17 ++++++--- synapse/rest/client/logout.py | 9 +++-- synapse/server.py | 2 +- tests/handlers/test_device.py | 19 ++++++---- tests/rest/admin/test_device.py | 5 ++- 16 files changed, 185 insertions(+), 77 deletions(-) create mode 100644 changelog.d/14055.misc (limited to 'tests/rest') diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc new file mode 100644 index 0000000000..02980bc528 --- /dev/null +++ b/changelog.d/14055.misc @@ -0,0 +1 @@ +Add missing type hints to `HomeServer`. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79..d74d135c0c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ class DeactivateAccountHandler: True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. + assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index da3ddafeae..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ class DeviceWorkerHandler: self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ class DeviceWorkerHandler: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. 
""" @@ -320,6 +337,8 @@ class DeviceWorkerHandler: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -882,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. + Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bf1221f523..5fe102e2f2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from twisted.internet import defer from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + 
edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ class E2eKeysHandler: # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ class E2eKeysHandler: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ class E2eKeysHandler: user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ class E2eKeysHandler: Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1200,6 +1202,9 @@ class E2eKeysHandler: A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. 
+ assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a1866..6307fa9c5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class RegistrationHandler: refresh_token = None refresh_token_id = None + # This can only run on the main process. + assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd40..bd9d0bb34b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. 
+ device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0..e1c0bff1b2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1035,6 +1036,8 @@ class SsoHandler: ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1045,12 @@ class SsoHandler: expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1adc1fd64f..96a661177a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ class ModuleApi: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,6 +786,8 @@ class ModuleApi: ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. Args: @@ -796,6 +800,10 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ class ModuleApi: if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index c21629def8..7c4941c3d3 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,7 +13,7 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116..fb73886df0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. - if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. 
+ if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d934880102..3b2f2d9abb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2..69b803f9f8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - 
self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518f..6d34625ad5 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/server.py b/synapse/server.py index f0a60d0056..5baae2325e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from typing import Optional from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 
+99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f92..03f2112b07 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.4.1 From f6c74d1cb2ed966802b01a2b037f09ce7a842c18 Mon Sep 17 00:00:00 2001 From: Benjamin Kampmann Date: Thu, 24 Nov 2022 09:10:51 +0000 Subject: Implement message forward pagination from start when no from is given, fixes #12383 (#14149) Fixes https://github.com/matrix-org/synapse/issues/12383 --- changelog.d/14149.bugfix | 1 + synapse/handlers/pagination.py | 6 ++++++ synapse/streams/events.py | 13 +++++++++++++ tests/rest/admin/test_room.py | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+) create mode 100644 changelog.d/14149.bugfix (limited to 'tests/rest') diff --git a/changelog.d/14149.bugfix b/changelog.d/14149.bugfix new file mode 100644 index 0000000000..b31c658266 --- /dev/null +++ b/changelog.d/14149.bugfix @@ -0,0 +1 @@ +Fix #12383: paginate room messages from the start if no from is given. Contributed by @gnunicorn . 
\ No newline at end of file
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index a4ca9cb8b4..c572508a02 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -448,6 +448,12 @@ class PaginationHandler:
 
         if pagin_config.from_token:
             from_token = pagin_config.from_token
+        elif pagin_config.direction == "f":
+            from_token = (
+                await self.hs.get_event_sources().get_start_token_for_pagination(
+                    room_id
+                )
+            )
         else:
             from_token = (
                 await self.hs.get_event_sources().get_current_token_for_pagination(
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index f331e1af16..619eb7f601 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -73,6 +73,19 @@ class EventSources:
         )
         return token
 
+    @trace
+    async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
+        """Get the start token for a given room to be used to paginate
+        events.
+
+        The returned token does not have the current values for fields other
+        than `room`, since they are not used during pagination.
+
+        Returns:
+            The start token for pagination.
+        """
+        return StreamToken.START
+
     @trace
     async def get_current_token_for_pagination(self, room_id: str) -> StreamToken:
         """Get the current token for a given room to be used to paginate
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index d156be82b0..e0f5d54aba 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1857,6 +1857,46 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
         self.assertIn("chunk", channel.json_body)
         self.assertIn("end", channel.json_body)
 
+    def test_room_messages_backward(self) -> None:
+        """Test that room messages can be paginated backwards by an admin who isn't in the room."""
+        latest_event_id = self.helper.send(
+            self.room_id, body="message 1", tok=self.user_tok
+        )["event_id"]
+
+        # Check the ordering of events when paginating backwards through /messages.
+        channel = self.make_request(
+            "GET",
+            "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+        # When paginating backwards, the latest event comes first.
+        self.assertEqual(chunk[0]["event_id"], latest_event_id)
+
+    def test_room_messages_forward(self) -> None:
+        """Test that room messages can be paginated forwards by an admin who isn't in the room."""
+        latest_event_id = self.helper.send(
+            self.room_id, body="message 1", tok=self.user_tok
+        )["event_id"]
+
+        # Check the ordering of events when paginating forwards through /messages.
+        channel = self.make_request(
+            "GET",
+            "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+        # When paginating forwards, the latest event comes last.
+        self.assertEqual(chunk[5]["event_id"], latest_event_id)
+
     def test_room_messages_purge(self) -> None:
         """Test room messages can be retrieved by an admin that isn't in the room."""
         store = self.hs.get_datastores().main
-- cgit 1.4.1


From 8f10c8b054fc970838be9ae6f1f5aea95f166c98 Mon Sep 17 00:00:00 2001
From: Eric Eastwood
Date: Mon, 28 Nov 2022 15:54:18 -0600
Subject: Move MSC3030 `/timestamp_to_event` endpoint to stable v1 location (#14471)

Fix https://github.com/matrix-org/synapse/issues/14390

 - Client API: `/_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<ts>&dir=<dir>` -> `/_matrix/client/v1/rooms/<roomID>/timestamp_to_event?ts=<ts>&dir=<dir>`
 - Federation API: `/_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/<roomID>?ts=<ts>&dir=<dir>` -> `/_matrix/federation/v1/timestamp_to_event/<roomID>?ts=<ts>&dir=<dir>`

Complement test changes: https://github.com/matrix-org/complement/pull/559
---
 changelog.d/14471.feature                           |  1 +
 docker/complement/conf/workers-shared-extra.yaml.j2 |  2 --
 docker/configure_workers_and_start.py               |  2 ++
 docs/workers.md                                     |  2 ++
 scripts-dev/complement.sh                           |  6 +++---
 synapse/config/experimental.py                      |  3 ---
 synapse/federation/federation_client.py             | 12 +++++++++++-
 synapse/federation/transport/client.py              |  5 ++---
 synapse/federation/transport/server/__init__.py     |  8 --------
 synapse/federation/transport/server/federation.py   |  3 +--
 synapse/rest/client/room.py                         | 10 +++-------
 synapse/rest/client/versions.py                     |  2 --
 tests/rest/client/test_rooms.py                     |  7 +------
 13 files changed, 26 insertions(+), 37 deletions(-)
 create mode 100644 changelog.d/14471.feature

(limited to 'tests/rest')

diff --git a/changelog.d/14471.feature b/changelog.d/14471.feature
new file mode 100644
index 0000000000..a0e0c74f1a
--- /dev/null
+++ b/changelog.d/14471.feature
@@ -0,0 +1 @@
+Move MSC3030 `/timestamp_to_event` endpoints to stable `v1` location (`/_matrix/client/v1/rooms/<roomID>/timestamp_to_event?ts=<ts>&dir=<dir>`, `/_matrix/federation/v1/timestamp_to_event/<roomID>?ts=<ts>&dir=<dir>`).
diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2
index 883a87159c..ca640c343b 100644
--- a/docker/complement/conf/workers-shared-extra.yaml.j2
+++ b/docker/complement/conf/workers-shared-extra.yaml.j2
@@ -100,8 +100,6 @@ experimental_features:
   # client-side support for partial state in /send_join responses
   faster_joins: true
   {% endif %}
-  # Enable jump to date endpoint
-  msc3030_enabled: true
 
   # Filtering /messages by relation type.
msc3874_enabled: true diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index c1e1544536..58c62f2231 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -140,6 +140,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event", "^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms", "^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases", + "^/_matrix/client/v1/rooms/.*/timestamp_to_event$", "^/_matrix/client/(api/v1|r0|v3|unstable)/search", ], "shared_extra_conf": {}, @@ -163,6 +164,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/federation/(v1|v2)/invite/", "^/_matrix/federation/(v1|v2)/query_auth/", "^/_matrix/federation/(v1|v2)/event_auth/", + "^/_matrix/federation/v1/timestamp_to_event/", "^/_matrix/federation/(v1|v2)/exchange_third_party_invite/", "^/_matrix/federation/(v1|v2)/user/devices/", "^/_matrix/federation/(v1|v2)/get_groups_publicised$", diff --git a/docs/workers.md b/docs/workers.md index 27e54c5846..2b65acb5ed 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -191,6 +191,7 @@ information. ^/_matrix/federation/(v1|v2)/send_leave/ ^/_matrix/federation/(v1|v2)/invite/ ^/_matrix/federation/v1/event_auth/ + ^/_matrix/federation/v1/timestamp_to_event/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ ^/_matrix/key/v2/query @@ -218,6 +219,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ + ^/_matrix/client/v1/rooms/.*/timestamp_to_event$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ # Encryption requests diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 803c6ce92d..7744b47097 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -162,9 +162,9 @@ else # We only test faster room joins on monoliths, because they are purposefully # being developed without worker support to start with. # - # The tests for importing historical messages (MSC2716) and jump to date (MSC3030) - # also only pass with monoliths, currently. - test_tags="$test_tags,faster_joins,msc2716,msc3030" + # The tests for importing historical messages (MSC2716) also only pass with monoliths, + # currently. + test_tags="$test_tags,faster_joins,msc2716" fi diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d4b71d1673..a503abf364 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -53,9 +53,6 @@ class ExperimentalConfig(Config): # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) - # MSC3030 (Jump to date API endpoint) - self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) - # MSC2409 (this setting only relates to optionally sending to-device messages). # Presence, typing and read receipt EDUs are already sent to application services that # have opted in to receive them. If enabled, this adds to-device messages to that list. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c4c0bc7315..8bccc9c60d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1691,9 +1691,19 @@ class FederationClient(FederationBase): # to return events on *both* sides of the timestamp to # help reconcile the gap faster. 
            _timestamp_to_event_from_destination,
+            # Since this endpoint is new, we should try other servers before giving up.
+            # We can safely remove this in a year (remove after 2023-11-16).
+            failover_on_unknown_endpoint=True,
         )
         return timestamp_to_event_response
-        except SynapseError:
+        except SynapseError as e:
+            logger.warning(
+                "timestamp_to_event(room_id=%s, timestamp=%s, direction=%s): encountered error when trying to fetch from destinations: %s",
+                room_id,
+                timestamp,
+                direction,
+                e,
+            )
             return None

     async def _timestamp_to_event_from_destination(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index a3cfc701cd..77f1f39cac 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -185,9 +185,8 @@ class TransportLayerClient:
         Raises:
             Various exceptions when the request fails
         """
-        path = _create_path(
-            FEDERATION_UNSTABLE_PREFIX,
-            "/org.matrix.msc3030/timestamp_to_event/%s",
+        path = _create_v1_path(
+            "/timestamp_to_event/%s",
             room_id,
         )
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 50623cd385..2725f53cf6 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -25,7 +25,6 @@ from synapse.federation.transport.server._base import (
 from synapse.federation.transport.server.federation import (
     FEDERATION_SERVLET_CLASSES,
     FederationAccountStatusServlet,
-    FederationTimestampLookupServlet,
 )
 from synapse.http.server import HttpServer, JsonResource
 from synapse.http.servlet import (
@@ -291,13 +290,6 @@ def register_servlets(
         )

         for servletclass in SERVLET_GROUPS[servlet_group]:
-            # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled
-            if (
-                servletclass == FederationTimestampLookupServlet
-                and not hs.config.experimental.msc3030_enabled
-            ):
-                continue
-
             # Only allow the `/account_status` servlet if msc3720 is enabled
             if (
                 servletclass == FederationAccountStatusServlet
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 205fd16daa..53e77b4bb6 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -218,14 +218,13 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
     `dir` can be `f` or `b` to indicate forwards and backwards in time from the
     given timestamp.

-    GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/<room_id>?ts=<timestamp>&dir=<direction>
+    GET /_matrix/federation/v1/timestamp_to_event/<room_id>?ts=<timestamp>&dir=<direction>
     {
         "event_id": ...
     }
     """

     PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?"
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030"

     async def on_GET(
         self,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 91cb791139..636cc62877 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -1284,17 +1284,14 @@ class TimestampLookupRestServlet(RestServlet):
     `dir` can be `f` or `b` to indicate forwards and backwards in time from the
     given timestamp.

-    GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>
+    GET /_matrix/client/v1/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>
     {
         "event_id": ...
    }
    """

    PATTERNS = (
-        re.compile(
-            "^/_matrix/client/unstable/org.matrix.msc3030"
-            "/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"
-        ),
+        re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"),
    )

    def __init__(self, hs: "HomeServer"):
@@ -1421,8 +1418,7 @@ def register_servlets(
     RoomAliasListServlet(hs).register(http_server)
     SearchRestServlet(hs).register(http_server)
     RoomCreateRestServlet(hs).register(http_server)
-    if hs.config.experimental.msc3030_enabled:
-        TimestampLookupRestServlet(hs).register(http_server)
+    TimestampLookupRestServlet(hs).register(http_server)

     # Some servlets only get registered for the main process.
     if not is_worker:
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 180a11ef88..3c0a90010b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -101,8 +101,6 @@ class VersionsRestServlet(RestServlet):
             "org.matrix.msc3827.stable": True,
             # Adds support for importing historical messages as per MSC2716
             "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
-            # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
-            "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
             # Adds support for thread relations, per MSC3440.
             "org.matrix.msc3440.stable": True,  # TODO: remove when "v1.3" is added above
             # Support for thread read receipts & notification counts.
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index e919e089cb..b4daace556 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -3546,11 +3546,6 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]

-    def default_config(self) -> JsonDict:
-        config = super().default_config()
-        config["experimental_features"] = {"msc3030_enabled": True}
-        return config
-
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self._storage_controllers = self.hs.get_storage_controllers()

@@ -3592,7 +3587,7 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):

         channel = self.make_request(
             "GET",
-            f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
+            f"/_matrix/client/v1/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
             access_token=self.room_owner_tok,
         )
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-- cgit 1.4.1


From ecb6fe9d9cf8375b760eb727be0e1dec3612e026 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Wed, 30 Nov 2022 11:59:57 +0000
Subject: Stop using deprecated `keyIds` param on /key/v2/server (#14525)

Fixes #14523.
---
 changelog.d/14490.feature                     |   1 +
 changelog.d/14490.misc                        |   1 -
 changelog.d/14525.feature                     |   1 +
 synapse/crypto/keyring.py                     | 107 +++++++++++---------------
 tests/crypto/test_keyring.py                  |  14 +---
 tests/rest/key/v2/test_remote_key_resource.py |   5 +-
 6 files changed, 47 insertions(+), 82 deletions(-)
 create mode 100644 changelog.d/14490.feature
 delete mode 100644 changelog.d/14490.misc
 create mode 100644 changelog.d/14525.feature
(limited to 'tests/rest')

diff --git a/changelog.d/14490.feature b/changelog.d/14490.feature
new file mode 100644
index 0000000000..c7cb571294
--- /dev/null
+++ b/changelog.d/14490.feature
@@ -0,0 +1 @@
+Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`.
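As a rough illustration of what this commit changes on the wire, here is a minimal standalone sketch (not part of the patch) of fetching a server's signing keys from the stable, parameter-less endpoint, as the reworked `get_server_verify_keys_v2_direct` in the keyring diff below now does. The server name and the plain-HTTPS transport are illustrative assumptions; real federation requests go through Synapse's server discovery and request-signing machinery.

import json
import urllib.request

def fetch_server_keys(server_name: str) -> dict:
    # One request now returns every key the remote server is willing to
    # serve; there is no longer a per-key-ID path component (`keyIds`).
    url = f"https://{server_name}/_matrix/key/v2/server"
    with urllib.request.urlopen(url, timeout=10) as resp:
        response = json.load(resp)
    # The same sanity check the fetcher performs: the response must be for
    # the server we asked about.
    if response["server_name"] != server_name:
        raise ValueError(
            "Expected a response for server %r not %r"
            % (server_name, response["server_name"])
        )
    return response["verify_keys"]

# Example usage (assumes the server is reachable over HTTPS on port 443):
# print(fetch_server_keys("matrix.org"))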
diff --git a/changelog.d/14490.misc b/changelog.d/14490.misc deleted file mode 100644 index c0a4daa885..0000000000 --- a/changelog.d/14490.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 0.9 where it would fail to fetch server keys whose IDs contain a forward slash. diff --git a/changelog.d/14525.feature b/changelog.d/14525.feature new file mode 100644 index 0000000000..c7cb571294 --- /dev/null +++ b/changelog.d/14525.feature @@ -0,0 +1 @@ +Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index ed15f88350..69310d9035 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,7 +14,6 @@ import abc import logging -import urllib from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr @@ -813,31 +812,27 @@ class ServerKeyFetcher(BaseV2KeyFetcher): results = {} - async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None: + async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None: server_name = key_to_fetch_item.server_name - key_ids = key_to_fetch_item.key_ids try: - keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) + keys = await self.get_server_verify_keys_v2_direct(server_name) results[server_name] = keys except KeyLookupError as e: - logger.warning( - "Error looking up keys %s from %s: %s", key_ids, server_name, e - ) + logger.warning("Error looking up keys from %s: %s", server_name, e) except Exception: - logger.exception("Error getting keys %s from %s", key_ids, server_name) + logger.exception("Error getting keys from %s", server_name) - await yieldable_gather_results(get_key, keys_to_fetch) + await yieldable_gather_results(get_keys, keys_to_fetch) return results - async def get_server_verify_key_v2_direct( - self, server_name: str, key_ids: Iterable[str] + async def get_server_verify_keys_v2_direct( + self, server_name: str ) -> Dict[str, FetchKeyResult]: """ Args: - server_name: - key_ids: + server_name: Server to request keys from Returns: Map from key ID to lookup result @@ -845,57 +840,41 @@ class ServerKeyFetcher(BaseV2KeyFetcher): Raises: KeyLookupError if there was a problem making the lookup """ - keys: Dict[str, FetchKeyResult] = {} - - for requested_key_id in key_ids: - # we may have found this key as a side-effect of asking for another. - if requested_key_id in keys: - continue - - time_now_ms = self.clock.time_msec() - try: - response = await self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id, safe=""), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). 
- timeout=10000, - ) - except (NotRetryingDestination, RequestSendFailed) as e: - # these both have str() representations which we can't really improve - # upon - raise KeyLookupError(str(e)) - except HttpResponseException as e: - raise KeyLookupError("Remote server returned an error: %s" % (e,)) - - assert isinstance(response, dict) - if response["server_name"] != server_name: - raise KeyLookupError( - "Expected a response for server %r not %r" - % (server_name, response["server_name"]) - ) - - response_keys = await self.process_v2_response( - from_server=server_name, - response_json=response, - time_added_ms=time_now_ms, + time_now_ms = self.clock.time_msec() + try: + response = await self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server", + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). + timeout=10000, ) - await self.store.store_server_verify_keys( - server_name, - time_now_ms, - ((server_name, key_id, key) for key_id, key in response_keys.items()), + except (NotRetryingDestination, RequestSendFailed) as e: + # these both have str() representations which we can't really improve + # upon + raise KeyLookupError(str(e)) + except HttpResponseException as e: + raise KeyLookupError("Remote server returned an error: %s" % (e,)) + + assert isinstance(response, dict) + if response["server_name"] != server_name: + raise KeyLookupError( + "Expected a response for server %r not %r" + % (server_name, response["server_name"]) ) - keys.update(response_keys) - return keys + return await self.process_v2_response( + from_server=server_name, + response_json=response, + time_added_ms=time_now_ms, + ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 63628aa6b0..f7c309cad0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -433,7 +433,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) - self.assertEqual(path, "/_matrix/key/v2/server/key1") + self.assertEqual(path, "/_matrix/key/v2/server") return response self.http_client.get_json.side_effect = get_json @@ -469,18 +469,6 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) - def test_keyid_containing_forward_slash(self) -> None: - """We should url-encode any url unsafe chars in key ids. - - Detects https://github.com/matrix-org/synapse/issues/14488. 
- """ - fetcher = ServerKeyFetcher(self.hs) - self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0)) - - self.http_client.get_json.assert_called_once() - args, kwargs = self.http_client.get_json.call_args - self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato") - class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 7f1fba1086..2bb6e27d94 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import urllib.parse from io import BytesIO, StringIO from typing import Any, Dict, Optional, Union from unittest.mock import Mock @@ -65,9 +64,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) - self.assertEqual( - path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),) - ) + self.assertEqual(path, "/_matrix/key/v2/server") response = { "server_name": server_name, -- cgit 1.4.1 From 60c3fea3271468dd1f9e9c5fae2d22dd9778293b Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 7 Dec 2022 17:35:41 +0000 Subject: Reject receipt requests with invalid room or event IDs. (#14632) If the room or event IDs are empty or of an invalid form they should be rejected. --- changelog.d/14632.bugfix | 1 + synapse/rest/client/receipts.py | 5 ++- tests/rest/client/test_receipts.py | 76 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14632.bugfix create mode 100644 tests/rest/client/test_receipts.py (limited to 'tests/rest') diff --git a/changelog.d/14632.bugfix b/changelog.d/14632.bugfix new file mode 100644 index 0000000000..323d10f1b0 --- /dev/null +++ b/changelog.d/14632.bugfix @@ -0,0 +1 @@ +Reject invalid read receipt requests with empty room or event IDs. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 18a282b22c..28b7d30ea8 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -20,7 +20,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict +from synapse.types import EventID, JsonDict, RoomID from ._base import client_patterns @@ -56,6 +56,9 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) + if not RoomID.is_valid(room_id) or not event_id.startswith(EventID.SIGIL): + raise SynapseError(400, "A valid room ID and event ID must be specified") + if receipt_type not in self._known_receipt_types: raise SynapseError( 400, diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py new file mode 100644 index 0000000000..2a7fcea386 --- /dev/null +++ b/tests/rest/client/test_receipts.py @@ -0,0 +1,76 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.rest.client import login, receipts, register
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class ReceiptsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        register.register_servlets,
+        receipts.register_servlets,
+        synapse.rest.admin.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.owner = self.register_user("owner", "pass")
+        self.owner_tok = self.login("owner", "pass")
+
+    def test_send_receipt(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "/rooms/!abc:beep/receipt/m.read/$def",
+            content={},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    def test_send_receipt_invalid_room_id(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "/rooms/not-a-room-id/receipt/m.read/$def",
+            content={},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(
+            channel.json_body["error"], "A valid room ID and event ID must be specified"
+        )
+
+    def test_send_receipt_invalid_event_id(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "/rooms/!abc:beep/receipt/m.read/not-an-event-id",
+            content={},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(
+            channel.json_body["error"], "A valid room ID and event ID must be specified"
+        )
+
+    def test_send_receipt_invalid_receipt_type(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "/rooms/!abc:beep/receipt/invalid-receipt-type/$def",
+            content={},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 400, channel.result)
-- cgit 1.4.1


From 9d8a3234ba1d3ff831a7647f45c67946773d88a7 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 8 Dec 2022 11:37:05 -0500
Subject: Respond with proper error responses on unknown paths. (#14621)

Returns a proper 404 with an errcode of M_UNRECOGNIZED for unknown
endpoints per MSC3743.
---
 changelog.d/14621.bugfix                      |  1 +
 synapse/api/errors.py                         |  6 ++----
 synapse/http/server.py                        | 19 ++++++++++++++++++-
 synapse/rest/media/v1/media_repository.py     |  4 ++--
 synapse/util/httpresourcetree.py              |  6 ++++--
 tests/rest/admin/test_user.py                 |  2 +-
 tests/rest/client/test_login_token_request.py |  4 ++--
 tests/rest/client/test_rendezvous.py          |  2 +-
 tests/test_server.py                          |  2 +-
 9 files changed, 32 insertions(+), 14 deletions(-)
 create mode 100644 changelog.d/14621.bugfix
(limited to 'tests/rest')

diff --git a/changelog.d/14621.bugfix b/changelog.d/14621.bugfix
new file mode 100644
index 0000000000..cb95a87d92
--- /dev/null
+++ b/changelog.d/14621.bugfix
@@ -0,0 +1 @@
+Return spec-compliant JSON errors when unknown endpoints are requested.
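Before the file-by-file changes, a minimal sketch (not part of the patch) of the client-visible effect: a request for an unknown path now yields a 404 whose JSON body carries the `M_UNRECOGNIZED` errcode, as MSC3743 specifies. The localhost address is an assumption for illustration.

import json
import urllib.error
import urllib.request

# Hypothetical local homeserver address; adjust for your deployment.
try:
    urllib.request.urlopen("http://localhost:8008/_matrix/foobar")
except urllib.error.HTTPError as e:
    body = json.load(e)
    assert e.code == 404  # was a 400 before this change
    assert body["errcode"] == "M_UNRECOGNIZED"
    assert body["error"] == "Unrecognized request"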
diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e2cfcea0f2..76ef12ed3a 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -300,10 +300,8 @@ class InteractiveAuthIncompleteError(Exception): class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" - def __init__( - self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED - ): - super().__init__(400, msg, errcode) + def __init__(self, msg: str = "Unrecognized request", code: int = 400): + super().__init__(code, msg, Codes.UNRECOGNIZED) class NotFoundError(SynapseError): diff --git a/synapse/http/server.py b/synapse/http/server.py index 051a1899a0..2563858f3c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -577,7 +577,24 @@ def _unrecognised_request_handler(request: Request) -> NoReturn: Args: request: Unused, but passed in to match the signature of ServletCallback. """ - raise UnrecognizedRequestError() + raise UnrecognizedRequestError(code=404) + + +class UnrecognizedRequestResource(resource.Resource): + """ + Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an + errcode of M_UNRECOGNIZED. + """ + + def render(self, request: SynapseRequest) -> int: + f = failure.Failure(UnrecognizedRequestError(code=404)) + return_json_error(f, request, None) + # A response has already been sent but Twisted requires either NOT_DONE_YET + # or the response bytes as a return value. + return NOT_DONE_YET + + def getChild(self, name: str, request: Request) -> resource.Resource: + return self class RootRedirect(resource.Resource): diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 40b0d39eb2..c70e1837af 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -24,7 +24,6 @@ from matrix_common.types.mxc_uri import MXCUri import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred -from twisted.web.resource import Resource from synapse.api.errors import ( FederationDeniedError, @@ -35,6 +34,7 @@ from synapse.api.errors import ( ) from synapse.config._base import ConfigError from synapse.config.repository import ThumbnailRequirement +from synapse.http.server import UnrecognizedRequestResource from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process @@ -1046,7 +1046,7 @@ class MediaRepository: return removed_media, len(removed_media) -class MediaRepositoryResource(Resource): +class MediaRepositoryResource(UnrecognizedRequestResource): """File uploading and downloading. 
Uploads are POSTed to a resource which returns a token which is used to GET diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index a0606851f7..39fab4fe06 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -15,7 +15,9 @@ import logging from typing import Dict -from twisted.web.resource import NoResource, Resource +from twisted.web.resource import Resource + +from synapse.http.server import UnrecognizedRequestResource logger = logging.getLogger(__name__) @@ -49,7 +51,7 @@ def create_resource_tree( for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource: Resource = NoResource() + child_resource: Resource = UnrecognizedRequestResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e8c9457794..5c1ced355f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -3994,7 +3994,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): """ Tests that shadow-banning for a user that is not a local returns a 400 """ - url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/shadow_ban" channel = self.make_request(method, url, access_token=self.admin_user_tok) self.assertEqual(400, channel.code, msg=channel.json_body) diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index c2e1e08811..6aedc1a11c 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -48,13 +48,13 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): def test_disabled(self) -> None: channel = self.make_request("POST", endpoint, {}, access_token=None) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) self.register_user(self.user, self.password) token = self.login(self.user, self.password) channel = self.make_request("POST", endpoint, {}, access_token=token) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) @override_config({"experimental_features": {"msc3882_enabled": True}}) def test_require_auth(self) -> None: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index ad00a476e1..c0eb5d01a6 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -36,7 +36,7 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase): def test_disabled(self) -> None: channel = self.make_request("POST", endpoint, {}, access_token=None) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) def test_redirect(self) -> None: diff --git a/tests/test_server.py b/tests/test_server.py index 2d9a0257d4..d67d7722a4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar" ) - self.assertEqual(channel.code, 400) + self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") -- cgit 1.4.1
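To round off the receipts change a few commits up, a minimal stand-in (not part of any patch above) for the ID validation the receipts servlet now performs. `RoomID.is_valid` in Synapse checks more than this; the sketch only mirrors the sigil rules (`!` for room IDs, `$` for event IDs) plus a crude `:domain` check, and is an illustrative simplification.

def is_valid_receipt_target(room_id: str, event_id: str) -> bool:
    # Simplified stand-in for RoomID.is_valid(room_id) combined with the
    # EventID.SIGIL prefix check from the servlet.
    return room_id.startswith("!") and ":" in room_id and event_id.startswith("$")

assert is_valid_receipt_target("!abc:beep", "$def")
assert not is_valid_receipt_target("not-a-room-id", "$def")
assert not is_valid_receipt_target("!abc:beep", "not-an-event-id")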