From 4c2096599c9780290703e14df63963e77d058dda Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 21 Jan 2022 08:38:36 +0000 Subject: Make the `get_global_account_data_by_type_for_user` cache be a tree-cache whose key is prefixed with the user ID (#11788) --- synapse/rest/client/account_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest') diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index d1badbdf3b..58b8adbd32 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -66,7 +66,7 @@ class AccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = await self.store.get_global_account_data_by_type_for_user( - account_data_type, user_id + user_id, account_data_type ) if event is None: -- cgit 1.5.1 From b784299cbc121d27d7dadd0a4a96f4657244a4e9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Jan 2022 05:31:31 -0500 Subject: Do not try to serialize raw aggregations dict. (#11791) --- changelog.d/11612.bugfix | 1 + changelog.d/11612.misc | 1 - changelog.d/11791.bugfix | 1 + synapse/events/utils.py | 4 +- synapse/rest/admin/rooms.py | 13 ++--- synapse/rest/client/room.py | 11 ++-- tests/rest/client/test_relations.py | 108 ++++++++++++++++++++++++------------ 7 files changed, 85 insertions(+), 54 deletions(-) create mode 100644 changelog.d/11612.bugfix delete mode 100644 changelog.d/11612.misc create mode 100644 changelog.d/11791.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/11612.bugfix b/changelog.d/11612.bugfix new file mode 100644 index 0000000000..842f6892fd --- /dev/null +++ b/changelog.d/11612.bugfix @@ -0,0 +1 @@ +Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/changelog.d/11612.misc b/changelog.d/11612.misc deleted file mode 100644 index 2d886169c5..0000000000 --- a/changelog.d/11612.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid database access in the JSON serialization process. diff --git a/changelog.d/11791.bugfix b/changelog.d/11791.bugfix new file mode 100644 index 0000000000..842f6892fd --- /dev/null +++ b/changelog.d/11791.bugfix @@ -0,0 +1 @@ +Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index de0e0c1731..918adeecf8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -402,7 +402,7 @@ class EventClientSerializer: if bundle_aggregations: event_aggregations = bundle_aggregations.get(event.event_id) if event_aggregations: - self._injected_bundled_aggregations( + self._inject_bundled_aggregations( event, time_now, bundle_aggregations[event.event_id], @@ -411,7 +411,7 @@ class EventClientSerializer: return serialized_event - def _injected_bundled_aggregations( + def _inject_bundled_aggregations( self, event: EventBase, time_now: int, diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 2e714ac87b..efe25fe7eb 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -744,20 +744,15 @@ class RoomEventContextServlet(RestServlet): ) time_now = self.clock.time_msec() + aggregations = results.pop("aggregations", None) results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_before"], time_now, bundle_aggregations=aggregations ) results["event"] = self._event_serializer.serialize_event( - results["event"], - time_now, - bundle_aggregations=results["aggregations"], + results["event"], time_now, bundle_aggregations=aggregations ) results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_after"], time_now, bundle_aggregations=aggregations ) results["state"] = self._event_serializer.serialize_events( results["state"], time_now diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 31fd329a38..90bb9142a0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -714,18 +714,15 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() + aggregations = results.pop("aggregations", None) results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_before"], time_now, bundle_aggregations=aggregations ) results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=results["aggregations"] + results["event"], time_now, bundle_aggregations=aggregations ) results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], - time_now, - bundle_aggregations=results["aggregations"], + results["events_after"], time_now, bundle_aggregations=aggregations ) results["state"] = self._event_serializer.serialize_events( results["state"], time_now diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 4b20ab0e3e..c9b220e73d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -21,6 +21,7 @@ from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.types import JsonDict from tests import unittest from tests.server import FakeChannel @@ -454,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_bundled_aggregations(self): - """Test that annotations, references, and threads get correctly bundled.""" + """ + Test that annotations, references, and threads get correctly bundled. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + + See test_edit for a similar test for edits. + """ # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -482,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - def assert_bundle(actual): + def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") # Ensure the fields are as expected. self.assertCountEqual( - actual.keys(), + relations_dict.keys(), ( RelationTypes.ANNOTATION, RelationTypes.REFERENCE, @@ -503,20 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"type": "m.reaction", "key": "b", "count": 1}, ] }, - actual[RelationTypes.ANNOTATION], + relations_dict[RelationTypes.ANNOTATION], ) self.assertEquals( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - actual[RelationTypes.REFERENCE], + relations_dict[RelationTypes.REFERENCE], ) self.assertEquals( 2, - actual[RelationTypes.THREAD].get("count"), + relations_dict[RelationTypes.THREAD].get("count"), ) self.assertTrue( - actual[RelationTypes.THREAD].get("current_user_participated") + relations_dict[RelationTypes.THREAD].get("current_user_participated") ) # The latest thread event has some fields that don't matter. self.assert_dict( @@ -533,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): "type": "m.room.test", "user_id": self.user_id, }, - actual[RelationTypes.THREAD].get("latest_event"), + relations_dict[RelationTypes.THREAD].get("latest_event"), ) - def _find_and_assert_event(events): - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - break - else: - raise AssertionError(f"Event {self.parent_id} not found in chunk") - assert_bundle(event["unsigned"].get("m.relations")) - # Request the event directly. channel = self.make_request( "GET", @@ -554,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["unsigned"].get("m.relations")) + assert_bundle(channel.json_body) # Request the room messages. channel = self.make_request( @@ -563,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - _find_and_assert_event(channel.json_body["chunk"]) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. channel = self.make_request( @@ -572,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + assert_bundle(channel.json_body["event"]) # Request sync. channel = self.make_request("GET", "/sync", access_token=self.user_token) self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - _find_and_assert_event(room_timeline["events"]) - - # Note that /relations is tested separately in test_aggregation_get_event_for_thread - # since it needs different data configured. + self._find_event_in_chunk(room_timeline["events"]) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included @@ -777,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + assert_bundle(channel.json_body) - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + # Request sync, but limit the timeline so it becomes limited (and includes + # bundled aggregations). + filter = urllib.parse.quote_plus( + '{"room": {"timeline": {"limit": 2}}}'.encode() + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token ) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_multi_edit(self): """Test that multiple edits, including attempts by people who @@ -1102,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["chunk"], []) + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + def _send_relation( self, relation_type: str, -- cgit 1.5.1 From 807efd26aec9b65c6a2f02d10fd139095a5b3387 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Jan 2022 08:58:18 -0500 Subject: Support rendering previews with data: URLs in them (#11767) Images which are data URLs will no longer break URL previews and will properly be "downloaded" and thumbnailed. --- changelog.d/11767.bugfix | 1 + synapse/rest/media/v1/preview_html.py | 31 +- synapse/rest/media/v1/preview_url_resource.py | 224 ++++++++---- tests/rest/media/v1/test_html_preview.py | 481 ++++++++++++++++++++++++++ tests/rest/media/v1/test_url_preview.py | 81 ++++- tests/server.py | 2 +- tests/test_preview.py | 449 ------------------------ 7 files changed, 747 insertions(+), 522 deletions(-) create mode 100644 changelog.d/11767.bugfix create mode 100644 tests/rest/media/v1/test_html_preview.py delete mode 100644 tests/test_preview.py (limited to 'synapse/rest') diff --git a/changelog.d/11767.bugfix b/changelog.d/11767.bugfix new file mode 100644 index 0000000000..3e344747f4 --- /dev/null +++ b/changelog.d/11767.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when previewing Reddit URLs which do not contain an image. diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 30b067dd42..872a9e72e8 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -321,14 +321,33 @@ def _iterate_over_text( def rebase_url(url: str, base: str) -> str: - base_parts = list(urlparse.urlparse(base)) + """ + Resolves a potentially relative `url` against an absolute `base` URL. + + For example: + + >>> rebase_url("subpage", "https://example.com/foo/") + 'https://example.com/foo/subpage' + >>> rebase_url("sibling", "https://example.com/foo") + 'https://example.com/sibling' + >>> rebase_url("/bar", "https://example.com/foo/") + 'https://example.com/bar' + >>> rebase_url("https://alice.com/a/", "https://example.com/foo/") + 'https://alice.com/a' + """ + base_parts = urlparse.urlparse(base) + # Convert the parsed URL to a list for (potential) modification. url_parts = list(urlparse.urlparse(url)) - if not url_parts[0]: # fix up schema - url_parts[0] = base_parts[0] or "http" - if not url_parts[1]: # fix up hostname - url_parts[1] = base_parts[1] + # Add a scheme, if one does not exist. + if not url_parts[0]: + url_parts[0] = base_parts.scheme or "http" + # Fix up the hostname, if this is not a data URL. + if url_parts[0] != "data" and not url_parts[1]: + url_parts[1] = base_parts.netloc + # If the path does not start with a /, nest it under the base path's last + # directory. if not url_parts[2].startswith("/"): - url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2] return urlparse.urlunparse(url_parts) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e8881bc870..efd84ced8f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -21,8 +21,9 @@ import re import shutil import sys import traceback -from typing import TYPE_CHECKING, Iterable, Optional, Tuple +from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple from urllib import parse as urlparse +from urllib.request import urlopen import attr @@ -70,6 +71,17 @@ ONE_DAY = 24 * ONE_HOUR IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DownloadResult: + length: int + uri: str + response_code: int + media_type: str + download_name: Optional[str] + expires: int + etag: Optional[str] + + @attr.s(slots=True, frozen=True, auto_attribs=True) class MediaInfo: """ @@ -256,7 +268,7 @@ class PreviewUrlResource(DirectServeJsonResource): if oembed_url: url_to_download = oembed_url - media_info = await self._download_url(url_to_download, user) + media_info = await self._handle_url(url_to_download, user) logger.debug("got media_info of '%s'", media_info) @@ -297,7 +309,9 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._download_url(oembed_url, user) + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) ( og_from_oembed, author_name, @@ -367,7 +381,135 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") - async def _download_url(self, url: str, user: UserID) -> MediaInfo: + async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResult: + """ + Fetches a remote URL and parses the headers. + + Args: + url: The URL to fetch. + output_stream: The stream to write the content to. + + Returns: + A tuple of: + Media length, URL downloaded, the HTTP response code, + the media type, the downloaded file name, the number of + milliseconds the result is valid for, the etag header. + """ + + try: + logger.debug("Trying to get preview for url '%s'", url) + length, headers, uri, code = await self.client.get_file( + url, + output_stream=output_stream, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, + ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, + ) + except Exception as e: + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading %s: %r", url, e) + + raise SynapseError( + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + + download_name = get_filename_from_headers(headers) + + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + expires = ONE_HOUR + etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + + return DownloadResult( + length, uri, code, media_type, download_name, expires, etag + ) + + async def _parse_data_url( + self, url: str, output_stream: BinaryIO + ) -> DownloadResult: + """ + Parses a data: URL. + + Args: + url: The URL to parse. + output_stream: The stream to write the content to. + + Returns: + A tuple of: + Media length, URL downloaded, the HTTP response code, + the media type, the downloaded file name, the number of + milliseconds the result is valid for, the etag header. + """ + + try: + logger.debug("Trying to parse data url '%s'", url) + with urlopen(url) as url_info: + # TODO Can this be more efficient. + output_stream.write(url_info.read()) + except Exception as e: + logger.warning("Error parsing data: URL %s: %r", url, e) + + raise SynapseError( + 500, + "Failed to parse data URL: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + + return DownloadResult( + # Read back the length that has been written. + length=output_stream.tell(), + uri=url, + # If it was parsed, consider this a 200 OK. + response_code=200, + # urlopen shoves the media-type from the data URL into the content type + # header object. + media_type=url_info.headers.get_content_type(), + # Some features are not supported by data: URLs. + download_name=None, + expires=ONE_HOUR, + etag=None, + ) + + async def _handle_url( + self, url: str, user: UserID, allow_data_urls: bool = False + ) -> MediaInfo: + """ + Fetches content from a URL and parses the result to generate a MediaInfo. + + It uses the media storage provider to persist the fetched content and + stores the mapping into the database. + + Args: + url: The URL to fetch. + user: The user who ahs requested this URL. + allow_data_urls: True if data URLs should be allowed. + + Returns: + A MediaInfo object describing the fetched content. + """ + # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -377,61 +519,27 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) with self.media_storage.store_into_file(file_info) as (f, fname, finish): - try: - logger.debug("Trying to get preview for url '%s'", url) - length, headers, uri, code = await self.client.get_file( - url, - output_stream=f, - max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, - ) - except SynapseError: - # Pass SynapseErrors through directly, so that the servlet - # handler will return a SynapseError to the client instead of - # blank data or a 500. - raise - except DNSLookupError: - # DNS lookup returned no results - # Note: This will also be the case if one of the resolved IP - # addresses is blacklisted - raise SynapseError( - 502, - "DNS resolution failure during URL preview generation", - Codes.UNKNOWN, - ) - except Exception as e: - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading %s: %r", url, e) - - raise SynapseError( - 500, - "Failed to download content: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - await finish() + if url.startswith("data:"): + if not allow_data_urls: + raise SynapseError( + 500, "Previewing of data: URLs is forbidden", Codes.UNKNOWN + ) - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") + download_result = await self._parse_data_url(url, f) else: - media_type = "application/octet-stream" + download_result = await self._download_url(url, f) - download_name = get_filename_from_headers(headers) - - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - expires = ONE_HOUR - etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + await finish() try: time_now_ms = self.clock.time_msec() await self.store.store_local_media( media_id=file_id, - media_type=media_type, + media_type=download_result.media_type, time_now_ms=time_now_ms, - upload_name=download_name, - media_length=length, + upload_name=download_result.download_name, + media_length=download_result.length, user_id=user, url_cache=url, ) @@ -444,16 +552,16 @@ class PreviewUrlResource(DirectServeJsonResource): raise return MediaInfo( - media_type=media_type, - media_length=length, - download_name=download_name, + media_type=download_result.media_type, + media_length=download_result.length, + download_name=download_result.download_name, created_ts_ms=time_now_ms, filesystem_id=file_id, filename=fname, - uri=uri, - response_code=code, - expires=expires, - etag=etag, + uri=download_result.uri, + response_code=download_result.response_code, + expires=download_result.expires, + etag=download_result.etag, ) async def _precache_image_url( @@ -474,8 +582,8 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - image_info = await self._download_url( - rebase_url(og["og:image"], media_info.uri), user + image_info = await self._handle_url( + rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True ) if _is_media(image_info.media_type): diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py new file mode 100644 index 0000000000..a4b57e3d1f --- /dev/null +++ b/tests/rest/media/v1/test_html_preview.py @@ -0,0 +1,481 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.media.v1.preview_html import ( + _get_html_media_encodings, + decode_body, + parse_html_to_open_graph, + rebase_url, + summarize_paragraphs, +) + +from tests import unittest + +try: + import lxml +except ImportError: + lxml = None + + +class SummarizeTestCase(unittest.TestCase): + if not lxml: + skip = "url preview feature requires lxml" + + def test_long_summarize(self): + example_paras = [ + """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: + Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in + Troms county, Norway. The administrative centre of the municipality is + the city of Tromsø. Outside of Norway, Tromso and Tromsö are + alternative spellings of the city.Tromsø is considered the northernmost + city in the world with a population above 50,000. The most populous town + north of it is Alta, Norway, with a population of 14,272 (2013).""", + """Tromsø lies in Northern Norway. The municipality has a population of + (2015) 72,066, but with an annual influx of students it has over 75,000 + most of the year. It is the largest urban area in Northern Norway and the + third largest north of the Arctic Circle (following Murmansk and Norilsk). + Most of Tromsø, including the city centre, is located on the island of + Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, + Tromsøya had a population of 36,088. Substantial parts of the urban area + are also situated on the mainland to the east, and on parts of Kvaløya—a + large island to the west. Tromsøya is connected to the mainland by the Tromsø + Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the + Sandnessund Bridge. Tromsø Airport connects the city to many destinations + in Europe. The city is warmer than most other places located on the same + latitude, due to the warming effect of the Gulf Stream.""", + """The city centre of Tromsø contains the highest number of old wooden + houses in Northern Norway, the oldest house dating from 1789. The Arctic + Cathedral, a modern church from 1965, is probably the most famous landmark + in Tromsø. The city is a cultural centre for its region, with several + festivals taking place in the summer. Some of Norway's best-known + musicians, Torbjørn Brundtland and Svein Berge of the electronica duo + Röyksopp and Lene Marlin grew up and started their careers in Tromsø. + Noted electronic musician Geir Jenssen also hails from Tromsø.""", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEqual( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway. The administrative centre of the municipality is" + " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" + " alternative spellings of the city.Tromsø is considered the northernmost" + " city in the world with a population above 50,000. The most populous town" + " north of it is Alta, Norway, with a population of 14,272 (2013).", + ) + + desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) + + self.assertEqual( + desc, + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. It is the largest urban area in Northern Norway and the" + " third largest north of the Arctic Circle (following Murmansk and Norilsk)." + " Most of Tromsø, including the city centre, is located on the island of" + " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," + " Tromsøya had a population of 36,088. Substantial parts of the urban…", + ) + + def test_short_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + "The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEqual( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + ) + + def test_small_then_large_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + " The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + self.assertEqual( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. The city centre of Tromsø contains the highest number" + " of old wooden houses in Northern Norway, the oldest house dating from" + " 1789. The Arctic Cathedral, a modern church from…", + ) + + +class CalcOgTestCase(unittest.TestCase): + if not lxml: + skip = "url preview feature requires lxml" + + def test_simple(self): + html = b""" + + Foo + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_comment(self): + html = b""" + + Foo + + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_comment2(self): + html = b""" + + Foo + + Some text. + + Some more text. +

Text

+ More text + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual( + og, + { + "og:title": "Foo", + "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text", + }, + ) + + def test_script(self): + html = b""" + + Foo + + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_missing_title(self): + html = b""" + + + Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + + def test_h1_as_title(self): + html = b""" + + + +

Title

+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) + + def test_missing_title_and_broken_h1(self): + html = b""" + + +

+ Some text. + + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + + def test_empty(self): + """Test a body with no data in it.""" + html = b"" + tree = decode_body(html, "http://example.com/test.html") + self.assertIsNone(tree) + + def test_no_tree(self): + """A valid body with no tree in it.""" + html = b"\x00" + tree = decode_body(html, "http://example.com/test.html") + self.assertIsNone(tree) + + def test_xml(self): + """Test decoding XML and ensure it works properly.""" + # Note that the strip() call is important to ensure the xml tag starts + # at the initial byte. + html = b""" + + + + + FooSome text. + """.strip() + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_invalid_encoding(self): + """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" + html = b""" + + Foo + + Some text. + + + """ + tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_invalid_encoding2(self): + """A body which doesn't match the sent character encoding.""" + # Note that this contains an invalid UTF-8 sequence in the title. + html = b""" + + \xff\xff Foo + + Some text. + + + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) + + def test_windows_1252(self): + """A body which uses cp1252, but doesn't declare that.""" + html = b""" + + \xf3 + + Some text. + + + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) + + +class MediaEncodingTestCase(unittest.TestCase): + def test_meta_charset(self): + """A character encoding is found via the meta tag.""" + encodings = _get_html_media_encodings( + b""" + + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + # A less well-formed version. + encodings = _get_html_media_encodings( + b""" + + < meta charset = ascii> + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + def test_meta_charset_underscores(self): + """A character encoding contains underscore.""" + encodings = _get_html_media_encodings( + b""" + + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) + + def test_xml_encoding(self): + """A character encoding is found via the meta tag.""" + encodings = _get_html_media_encodings( + b""" + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + def test_meta_xml_encoding(self): + """Meta tags take precedence over XML encoding.""" + encodings = _get_html_media_encodings( + b""" + + + + + + """, + "text/html", + ) + self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) + + def test_content_type(self): + """A character encoding is found via the Content-Type header.""" + # Test a few variations of the header. + headers = ( + 'text/html; charset="ascii";', + "text/html;charset=ascii;", + 'text/html; charset="ascii"', + "text/html; charset=ascii", + 'text/html; charset="ascii;', + 'text/html; charset=ascii";', + ) + for header in headers: + encodings = _get_html_media_encodings(b"", header) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) + + def test_fallback(self): + """A character encoding cannot be found in the body or header.""" + encodings = _get_html_media_encodings(b"", "text/html") + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_duplicates(self): + """Ensure each encoding is only attempted once.""" + encodings = _get_html_media_encodings( + b""" + + + + + + """, + 'text/html; charset="UTF_8"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_unknown_invalid(self): + """A character encoding should be ignored if it is unknown or invalid.""" + encodings = _get_html_media_encodings( + b""" + + + + + """, + 'text/html; charset="invalid"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + +class RebaseUrlTestCase(unittest.TestCase): + def test_relative(self): + """Relative URLs should be resolved based on the context of the base URL.""" + self.assertEqual( + rebase_url("subpage", "https://example.com/foo/"), + "https://example.com/foo/subpage", + ) + self.assertEqual( + rebase_url("sibling", "https://example.com/foo"), + "https://example.com/sibling", + ) + self.assertEqual( + rebase_url("/bar", "https://example.com/foo/"), + "https://example.com/bar", + ) + + def test_absolute(self): + """Absolute URLs should not be modified.""" + self.assertEqual( + rebase_url("https://alice.com/a/", "https://example.com/foo/"), + "https://alice.com/a/", + ) + + def test_data(self): + """Data URLs should not be modified.""" + self.assertEqual( + rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), + "data:,Hello%2C%20World%21", + ) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 16e904f15b..53f6186213 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -12,9 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import json import os import re +from urllib.parse import urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -23,6 +25,7 @@ from twisted.test.proto_helpers import AccumulatingProtocol from synapse.config.oembed import OEmbedEndpointConfig from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS +from synapse.types import JsonDict from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest @@ -142,6 +145,14 @@ class URLPreviewTests(unittest.HomeserverTestCase): def create_test_resource(self): return self.hs.get_media_repository_resource() + def _assert_small_png(self, json_body: JsonDict) -> None: + """Assert properties from the SMALL_PNG test image.""" + self.assertTrue(json_body["og:image"].startswith("mxc://")) + self.assertEqual(json_body["og:image:height"], 1) + self.assertEqual(json_body["og:image:width"], 1) + self.assertEqual(json_body["og:image:type"], "image/png") + self.assertEqual(json_body["matrix:image:size"], 67) + def test_cache_returns_correct_type(self): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] @@ -569,6 +580,66 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) + def test_data_url(self): + """ + Requesting to preview a data URL is not supported. + """ + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + data = base64.b64encode(SMALL_PNG).decode() + + query_params = urlencode( + { + "url": f'' + } + ) + + channel = self.make_request( + "GET", + f"preview_url?{query_params}", + shorthand=False, + ) + self.pump() + + self.assertEqual(channel.code, 500) + + def test_inline_data_url(self): + """ + An inline image (as a data URL) should be parsed properly. + """ + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + data = base64.b64encode(SMALL_PNG) + + end_content = ( + b"" b'' b"" + ) % (data,) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self._assert_small_png(channel.json_body) + def test_oembed_photo(self): """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] @@ -626,10 +697,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) body = channel.json_body self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") - self.assertTrue(body["og:image"].startswith("mxc://")) - self.assertEqual(body["og:image:height"], 1) - self.assertEqual(body["og:image:width"], 1) - self.assertEqual(body["og:image:type"], "image/png") + self._assert_small_png(body) def test_oembed_rich(self): """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" @@ -820,10 +888,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345" ) - self.assertTrue(body["og:image"].startswith("mxc://")) - self.assertEqual(body["og:image:height"], 1) - self.assertEqual(body["og:image:width"], 1) - self.assertEqual(body["og:image:type"], "image/png") + self._assert_small_png(body) def _download_image(self): """Downloads an image into the URL cache. diff --git a/tests/server.py b/tests/server.py index a0cd14ea45..82990c2eb9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -313,7 +313,7 @@ def make_request( req = request(channel, site) req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. - req.content.seek(SEEK_END) + req.content.seek(0, SEEK_END) if access_token: req.requestHeaders.addRawHeader( diff --git a/tests/test_preview.py b/tests/test_preview.py deleted file mode 100644 index 46e02f483f..0000000000 --- a/tests/test_preview.py +++ /dev/null @@ -1,449 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest.media.v1.preview_html import ( - _get_html_media_encodings, - decode_body, - parse_html_to_open_graph, - summarize_paragraphs, -) - -from . import unittest - -try: - import lxml -except ImportError: - lxml = None - - -class SummarizeTestCase(unittest.TestCase): - if not lxml: - skip = "url preview feature requires lxml" - - def test_long_summarize(self): - example_paras = [ - """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: - Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in - Troms county, Norway. The administrative centre of the municipality is - the city of Tromsø. Outside of Norway, Tromso and Tromsö are - alternative spellings of the city.Tromsø is considered the northernmost - city in the world with a population above 50,000. The most populous town - north of it is Alta, Norway, with a population of 14,272 (2013).""", - """Tromsø lies in Northern Norway. The municipality has a population of - (2015) 72,066, but with an annual influx of students it has over 75,000 - most of the year. It is the largest urban area in Northern Norway and the - third largest north of the Arctic Circle (following Murmansk and Norilsk). - Most of Tromsø, including the city centre, is located on the island of - Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, - Tromsøya had a population of 36,088. Substantial parts of the urban area - are also situated on the mainland to the east, and on parts of Kvaløya—a - large island to the west. Tromsøya is connected to the mainland by the Tromsø - Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the - Sandnessund Bridge. Tromsø Airport connects the city to many destinations - in Europe. The city is warmer than most other places located on the same - latitude, due to the warming effect of the Gulf Stream.""", - """The city centre of Tromsø contains the highest number of old wooden - houses in Northern Norway, the oldest house dating from 1789. The Arctic - Cathedral, a modern church from 1965, is probably the most famous landmark - in Tromsø. The city is a cultural centre for its region, with several - festivals taking place in the summer. Some of Norway's best-known - musicians, Torbjørn Brundtland and Svein Berge of the electronica duo - Röyksopp and Lene Marlin grew up and started their careers in Tromsø. - Noted electronic musician Geir Jenssen also hails from Tromsø.""", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway. The administrative centre of the municipality is" - " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" - " alternative spellings of the city.Tromsø is considered the northernmost" - " city in the world with a population above 50,000. The most populous town" - " north of it is Alta, Norway, with a population of 14,272 (2013).", - ) - - desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. It is the largest urban area in Northern Norway and the" - " third largest north of the Arctic Circle (following Murmansk and Norilsk)." - " Most of Tromsø, including the city centre, is located on the island of" - " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," - " Tromsøya had a population of 36,088. Substantial parts of the urban…", - ) - - def test_short_summarize(self): - example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year.", - "The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year.", - ) - - def test_small_then_large_summarize(self): - example_paras = [ - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.", - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year." - " The city centre of Tromsø contains the highest number of old wooden" - " houses in Northern Norway, the oldest house dating from 1789. The Arctic" - " Cathedral, a modern church from 1965, is probably the most famous landmark" - " in Tromsø.", - ] - - desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEqual( - desc, - "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - " Troms county, Norway.\n" - "\n" - "Tromsø lies in Northern Norway. The municipality has a population of" - " (2015) 72,066, but with an annual influx of students it has over 75,000" - " most of the year. The city centre of Tromsø contains the highest number" - " of old wooden houses in Northern Norway, the oldest house dating from" - " 1789. The Arctic Cathedral, a modern church from…", - ) - - -class CalcOgTestCase(unittest.TestCase): - if not lxml: - skip = "url preview feature requires lxml" - - def test_simple(self): - html = b""" - - Foo - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_comment(self): - html = b""" - - Foo - - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_comment2(self): - html = b""" - - Foo - - Some text. - - Some more text. -

Text

- More text - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual( - og, - { - "og:title": "Foo", - "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text", - }, - ) - - def test_script(self): - html = b""" - - Foo - - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_missing_title(self): - html = b""" - - - Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - - def test_h1_as_title(self): - html = b""" - - - -

Title

- - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) - - def test_missing_title_and_broken_h1(self): - html = b""" - - -

- Some text. - - - """ - - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - - self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) - - def test_empty(self): - """Test a body with no data in it.""" - html = b"" - tree = decode_body(html, "http://example.com/test.html") - self.assertIsNone(tree) - - def test_no_tree(self): - """A valid body with no tree in it.""" - html = b"\x00" - tree = decode_body(html, "http://example.com/test.html") - self.assertIsNone(tree) - - def test_xml(self): - """Test decoding XML and ensure it works properly.""" - # Note that the strip() call is important to ensure the xml tag starts - # at the initial byte. - html = b""" - - - - - FooSome text. - """.strip() - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_invalid_encoding(self): - """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" - html = b""" - - Foo - - Some text. - - - """ - tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) - - def test_invalid_encoding2(self): - """A body which doesn't match the sent character encoding.""" - # Note that this contains an invalid UTF-8 sequence in the title. - html = b""" - - \xff\xff Foo - - Some text. - - - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) - - def test_windows_1252(self): - """A body which uses cp1252, but doesn't declare that.""" - html = b""" - - \xf3 - - Some text. - - - """ - tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") - self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) - - -class MediaEncodingTestCase(unittest.TestCase): - def test_meta_charset(self): - """A character encoding is found via the meta tag.""" - encodings = _get_html_media_encodings( - b""" - - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - # A less well-formed version. - encodings = _get_html_media_encodings( - b""" - - < meta charset = ascii> - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_meta_charset_underscores(self): - """A character encoding contains underscore.""" - encodings = _get_html_media_encodings( - b""" - - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) - - def test_xml_encoding(self): - """A character encoding is found via the meta tag.""" - encodings = _get_html_media_encodings( - b""" - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_meta_xml_encoding(self): - """Meta tags take precedence over XML encoding.""" - encodings = _get_html_media_encodings( - b""" - - - - - - """, - "text/html", - ) - self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) - - def test_content_type(self): - """A character encoding is found via the Content-Type header.""" - # Test a few variations of the header. - headers = ( - 'text/html; charset="ascii";', - "text/html;charset=ascii;", - 'text/html; charset="ascii"', - "text/html; charset=ascii", - 'text/html; charset="ascii;', - 'text/html; charset=ascii";', - ) - for header in headers: - encodings = _get_html_media_encodings(b"", header) - self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) - - def test_fallback(self): - """A character encoding cannot be found in the body or header.""" - encodings = _get_html_media_encodings(b"", "text/html") - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - def test_duplicates(self): - """Ensure each encoding is only attempted once.""" - encodings = _get_html_media_encodings( - b""" - - - - - - """, - 'text/html; charset="UTF_8"', - ) - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - def test_unknown_invalid(self): - """A character encoding should be ignored if it is unknown or invalid.""" - encodings = _get_html_media_encodings( - b""" - - - - - """, - 'text/html; charset="invalid"', - ) - self.assertEqual(list(encodings), ["utf-8", "cp1252"]) -- cgit 1.5.1 From 0d6cfea9b867a14fa0fa885b04c8cbfdb4a7c4a9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 25 Jan 2022 13:06:29 +0100 Subject: Add admin API to reset connection timeouts for remote server (#11639) * Fix get federation status of destination if no error occured --- changelog.d/11639.feature | 1 + docs/usage/administration/admin_api/federation.md | 40 +++++++++++++++- synapse/federation/transport/server/__init__.py | 16 ++++--- synapse/federation/transport/server/_base.py | 14 +++--- synapse/federation/transport/server/federation.py | 24 +++++++--- .../federation/transport/server/groups_local.py | 8 ++-- .../federation/transport/server/groups_server.py | 8 ++-- synapse/rest/admin/__init__.py | 6 ++- synapse/rest/admin/federation.py | 44 ++++++++++++++++- tests/rest/admin/test_federation.py | 55 ++++++++++++++++++++-- 10 files changed, 183 insertions(+), 33 deletions(-) create mode 100644 changelog.d/11639.feature (limited to 'synapse/rest') diff --git a/changelog.d/11639.feature b/changelog.d/11639.feature new file mode 100644 index 0000000000..e9f6704f7a --- /dev/null +++ b/changelog.d/11639.feature @@ -0,0 +1 @@ +Add admin API to reset connection timeouts for remote server. \ No newline at end of file diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md index 8f9535f57b..5e609561a6 100644 --- a/docs/usage/administration/admin_api/federation.md +++ b/docs/usage/administration/admin_api/federation.md @@ -86,7 +86,7 @@ The following fields are returned in the JSON response body: - `next_token`: string representing a positive integer - Indication for pagination. See above. - `total` - integer - Total number of destinations. -# Destination Details API +## Destination Details API This API gets the retry timing info for a specific remote server. @@ -108,7 +108,45 @@ A response body like the following is returned: } ``` +**Parameters** + +The following parameters should be set in the URL: + +- `destination` - Name of the remote server. + **Response** The response fields are the same like in the `destinations` array in [List of destinations](#list-of-destinations) response. + +## Reset connection timeout + +Synapse makes federation requests to other homeservers. If a federation request fails, +Synapse will mark the destination homeserver as offline, preventing any future requests +to that server for a "cooldown" period. This period grows over time if the server +continues to fail its responses +([exponential backoff](https://en.wikipedia.org/wiki/Exponential_backoff)). + +Admins can cancel the cooldown period with this API. + +This API resets the retry timing for a specific remote server and tries to connect to +the remote server again. It does not wait for the next `retry_interval`. +The connection must have previously run into an error and `retry_last_ts` +([Destination Details API](#destination-details-api)) must not be equal to `0`. + +The connection attempt is carried out in the background and can take a while +even if the API already returns the http status 200. + +The API is: + +``` +POST /_synapse/admin/v1/federation/destinations//reset_connection + +{} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `destination` - Name of the remote server. diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 77b936361a..db4fe2c798 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type from typing_extensions import Literal @@ -36,17 +36,19 @@ from synapse.http.servlet import ( parse_integer_from_args, parse_string_from_args, ) -from synapse.server import HomeServer from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.ratelimitutils import FederationRateLimiter +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): + def __init__(self, hs: "HomeServer", servlet_groups: Optional[List[str]] = None): """Initialize the TransportLayerServer Will by default register all servlets. For custom behaviour, pass in @@ -113,7 +115,7 @@ class PublicRoomList(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -203,7 +205,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -251,7 +253,7 @@ class OpenIdUserInfo(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -297,7 +299,7 @@ DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = { def register_servlets( - hs: HomeServer, + hs: "HomeServer", resource: HttpServer, authenticator: Authenticator, ratelimiter: FederationRateLimiter, diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index da1fbf8b63..2ca7c05835 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -15,7 +15,7 @@ import functools import logging import re -from typing import Any, Awaitable, Callable, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, cast from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_V1_PREFIX @@ -29,11 +29,13 @@ from synapse.logging.opentracing import ( start_active_span_follows_from, whitelisted_homeserver, ) -from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import parse_and_validate_server_name +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -46,7 +48,7 @@ class NoAuthenticationError(AuthenticationError): class Authenticator: - def __init__(self, hs: HomeServer): + def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname @@ -114,11 +116,11 @@ class Authenticator: # alive retry_timings = await self.store.get_destination_retry_timings(origin) if retry_timings and retry_timings.retry_last_ts: - run_in_background(self._reset_retry_timings, origin) + run_in_background(self.reset_retry_timings, origin) return origin - async def _reset_retry_timings(self, origin: str) -> None: + async def reset_retry_timings(self, origin: str) -> None: try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) @@ -227,7 +229,7 @@ class BaseFederationServlet: def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index beadfa422b..9c1ad5851f 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -12,7 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) from typing_extensions import Literal @@ -30,11 +40,13 @@ from synapse.http.servlet import ( parse_string_from_args, parse_strings_from_args, ) -from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) issue_8631_logger = logging.getLogger("synapse.8631_debug") @@ -47,7 +59,7 @@ class BaseFederationServerServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -596,7 +608,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -670,7 +682,7 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, @@ -706,7 +718,7 @@ class RoomComplexityServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py index a12cd18d58..496472e1dc 100644 --- a/synapse/federation/transport/server/groups_local.py +++ b/synapse/federation/transport/server/groups_local.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Tuple, Type from synapse.api.errors import SynapseError from synapse.federation.transport.server._base import ( @@ -19,10 +19,12 @@ from synapse.federation.transport.server._base import ( BaseFederationServlet, ) from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.server import HomeServer from synapse.types import JsonDict, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter +if TYPE_CHECKING: + from synapse.server import HomeServer + class BaseGroupsLocalServlet(BaseFederationServlet): """Abstract base class for federation servlet classes which provides a groups local handler. @@ -32,7 +34,7 @@ class BaseGroupsLocalServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py index b30e92a5eb..851b50152e 100644 --- a/synapse/federation/transport/server/groups_server.py +++ b/synapse/federation/transport/server/groups_server.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Tuple, Type from typing_extensions import Literal @@ -22,10 +22,12 @@ from synapse.federation.transport.server._base import ( BaseFederationServlet, ) from synapse.http.servlet import parse_string_from_args -from synapse.server import HomeServer from synapse.types import JsonDict, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter +if TYPE_CHECKING: + from synapse.server import HomeServer + class BaseGroupsServerServlet(BaseFederationServlet): """Abstract base class for federation servlet classes which provides a groups server handler. @@ -35,7 +37,7 @@ class BaseGroupsServerServlet(BaseFederationServlet): def __init__( self, - hs: HomeServer, + hs: "HomeServer", authenticator: Authenticator, ratelimiter: FederationRateLimiter, server_name: str, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 465e06772b..b1e49d51b7 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -41,7 +41,8 @@ from synapse.rest.admin.event_reports import ( EventReportsRestServlet, ) from synapse.rest.admin.federation import ( - DestinationsRestServlet, + DestinationResetConnectionRestServlet, + DestinationRestServlet, ListDestinationsRestServlet, ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet @@ -267,7 +268,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) - DestinationsRestServlet(hs).register(http_server) + DestinationResetConnectionRestServlet(hs).register(http_server) + DestinationRestServlet(hs).register(http_server) ListDestinationsRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 8cd3fa189e..0f33f9e4da 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.federation.transport.server import Authenticator from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin @@ -90,7 +91,7 @@ class ListDestinationsRestServlet(RestServlet): return HTTPStatus.OK, response -class DestinationsRestServlet(RestServlet): +class DestinationRestServlet(RestServlet): """Get details of a destination. This needs user to have administrator access in Synapse. @@ -145,3 +146,44 @@ class DestinationsRestServlet(RestServlet): } return HTTPStatus.OK, response + + +class DestinationResetConnectionRestServlet(RestServlet): + """Reset destinations' connection timeouts and wake it up. + This needs user to have administrator access in Synapse. + + POST /_synapse/admin/v1/federation/destinations//reset_connection + {} + + returns: + 200 OK otherwise an error. + """ + + PATTERNS = admin_patterns( + "/federation/destinations/(?P[^/]+)/reset_connection$" + ) + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._authenticator = Authenticator(hs) + + async def on_POST( + self, request: SynapseRequest, destination: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not await self._store.is_destination_known(destination): + raise NotFoundError("Unknown destination") + + retry_timings = await self._store.get_destination_retry_timings(destination) + if not (retry_timings and retry_timings.retry_last_ts): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The retry timing does not need to be reset for this destination.", + ) + + # reset timings and wake up + await self._authenticator.reset_retry_timings(destination) + + return HTTPStatus.OK, {} diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index b70350b6f1..e2d3cff2a3 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -43,11 +43,15 @@ class FederationTestCase(unittest.HomeserverTestCase): @parameterized.expand( [ - ("/_synapse/admin/v1/federation/destinations",), - ("/_synapse/admin/v1/federation/destinations/dummy",), + ("GET", "/_synapse/admin/v1/federation/destinations"), + ("GET", "/_synapse/admin/v1/federation/destinations/dummy"), + ( + "POST", + "/_synapse/admin/v1/federation/destinations/dummy/reset_connection", + ), ] ) - def test_requester_is_no_admin(self, url: str) -> None: + def test_requester_is_no_admin(self, method: str, url: str) -> None: """ If the user is not a server admin, an error 403 is returned. """ @@ -56,7 +60,7 @@ class FederationTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") channel = self.make_request( - "GET", + method, url, content={}, access_token=other_user_tok, @@ -120,6 +124,16 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + # invalid destination + channel = self.make_request( + "POST", + self.url + "/dummy/reset_connection", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + def test_limit(self) -> None: """ Testing list of destinations with limit @@ -444,6 +458,39 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertIsNone(channel.json_body["failure_ts"]) self.assertIsNone(channel.json_body["last_successful_stream_ordering"]) + def test_destination_reset_connection(self) -> None: + """Reset timeouts and wake up destination.""" + self._create_destination("sub0.example.com", 100, 100, 100) + + channel = self.make_request( + "POST", + self.url + "/sub0.example.com/reset_connection", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + retry_timings = self.get_success( + self.store.get_destination_retry_timings("sub0.example.com") + ) + self.assertIsNone(retry_timings) + + def test_destination_reset_connection_not_required(self) -> None: + """Try to reset timeouts of a destination with no timeouts and get an error.""" + self._create_destination("sub0.example.com", None, 0, 0) + + channel = self.make_request( + "POST", + self.url + "/sub0.example.com/reset_connection", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "The retry timing does not need to be reset for this destination.", + channel.json_body["error"], + ) + def _create_destination( self, destination: str, -- cgit 1.5.1 From 6a72c910f180ee8b4bd78223775af48492769472 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 25 Jan 2022 17:11:40 +0100 Subject: Add admin API to get a list of federated rooms (#11658) --- changelog.d/11658.feature | 1 + docs/usage/administration/admin_api/federation.md | 60 +++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/federation.py | 56 ++++ synapse/storage/databases/main/transactions.py | 48 ++++ tests/rest/admin/test_federation.py | 302 ++++++++++++++++++++-- 6 files changed, 444 insertions(+), 25 deletions(-) create mode 100644 changelog.d/11658.feature (limited to 'synapse/rest') diff --git a/changelog.d/11658.feature b/changelog.d/11658.feature new file mode 100644 index 0000000000..2ec9fb5eec --- /dev/null +++ b/changelog.d/11658.feature @@ -0,0 +1 @@ +Add an admin API to get a list of rooms that federate with a given remote homeserver. \ No newline at end of file diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md index 5e609561a6..60cbc5265e 100644 --- a/docs/usage/administration/admin_api/federation.md +++ b/docs/usage/administration/admin_api/federation.md @@ -119,6 +119,66 @@ The following parameters should be set in the URL: The response fields are the same like in the `destinations` array in [List of destinations](#list-of-destinations) response. +## Destination rooms + +This API gets the rooms that federate with a specific remote server. + +The API is: + +``` +GET /_synapse/admin/v1/federation/destinations//rooms +``` + +A response body like the following is returned: + +```json +{ + "rooms":[ + { + "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", + "stream_ordering": 8326 + }, + { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "stream_ordering": 93534 + } + ], + "total": 2 +} +``` + +To paginate, check for `next_token` and if present, call the endpoint again +with `from` set to the value of `next_token`. This will return a new page. + +If the endpoint does not return a `next_token` then there are no more destinations +to paginate through. + +**Parameters** + +The following parameters should be set in the URL: + +- `destination` - Name of the remote server. + +The following query parameters are available: + +- `from` - Offset in the returned list. Defaults to `0`. +- `limit` - Maximum amount of destinations to return. Defaults to `100`. +- `dir` - Direction of room order by `room_id`. Either `f` for forwards or `b` for + backwards. Defaults to `f`. + +**Response** + +The following fields are returned in the JSON response body: + +- `rooms` - An array of objects, each containing information about a room. + Room objects contain the following fields: + - `room_id` - string - The ID of the room. + - `stream_ordering` - integer - The stream ordering of the most recent + successfully-sent [PDU](understanding_synapse_through_grafana_graphs.md#federation) + to this destination in this room. +- `next_token`: string representing a positive integer - Indication for pagination. See above. +- `total` - integer - Total number of destinations. + ## Reset connection timeout Synapse makes federation requests to other homeservers. If a federation request fails, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index b1e49d51b7..9be9e33c8e 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -41,6 +41,7 @@ from synapse.rest.admin.event_reports import ( EventReportsRestServlet, ) from synapse.rest.admin.federation import ( + DestinationMembershipRestServlet, DestinationResetConnectionRestServlet, DestinationRestServlet, ListDestinationsRestServlet, @@ -268,6 +269,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + DestinationMembershipRestServlet(hs).register(http_server) DestinationResetConnectionRestServlet(hs).register(http_server) DestinationRestServlet(hs).register(http_server) ListDestinationsRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 0f33f9e4da..d162e0081e 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -148,6 +148,62 @@ class DestinationRestServlet(RestServlet): return HTTPStatus.OK, response +class DestinationMembershipRestServlet(RestServlet): + """Get list of rooms of a destination. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v1/federation/destinations//rooms?from=0&limit=10 + + returns: + 200 OK with a list of rooms if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + """ + + PATTERNS = admin_patterns("/federation/destinations/(?P[^/]*)/rooms$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, destination: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not await self._store.is_destination_known(destination): + raise NotFoundError("Unknown destination") + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + + rooms, total = await self._store.get_destination_rooms_paginate( + destination, start, limit, direction + ) + response = {"rooms": rooms, "total": total} + if (start + limit) < total: + response["next_token"] = str(start + len(rooms)) + + return HTTPStatus.OK, response + + class DestinationResetConnectionRestServlet(RestServlet): """Reset destinations' connection timeouts and wake it up. This needs user to have administrator access in Synapse. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 4b78b4d098..ba79e19f7f 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -561,6 +561,54 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): "get_destinations_paginate_txn", get_destinations_paginate_txn ) + async def get_destination_rooms_paginate( + self, destination: str, start: int, limit: int, direction: str = "f" + ) -> Tuple[List[JsonDict], int]: + """Function to retrieve a paginated list of destination's rooms. + This will return a json list of rooms and the + total number of rooms. + + Args: + destination: the destination to query + start: start number to begin the query from + limit: number of rows to retrieve + direction: sort ascending or descending by room_id + Returns: + A tuple of a dict of rooms and a count of total rooms. + """ + + def get_destination_rooms_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT COUNT(*) as total_rooms + FROM destination_rooms + WHERE destination = ? + """ + txn.execute(sql, [destination]) + count = cast(Tuple[int], txn.fetchone())[0] + + rooms = self.db_pool.simple_select_list_paginate_txn( + txn=txn, + table="destination_rooms", + orderby="room_id", + start=start, + limit=limit, + retcols=("room_id", "stream_ordering"), + order_direction=order, + ) + return rooms, count + + return await self.db_pool.runInteraction( + "get_destination_rooms_paginate_txn", get_destination_rooms_paginate_txn + ) + async def is_destination_known(self, destination: str) -> bool: """Check if a destination is known to the server.""" result = await self.db_pool.simple_select_one_onecol( diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index e2d3cff2a3..71068d16cd 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -20,7 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client import login +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -52,9 +52,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] ) def test_requester_is_no_admin(self, method: str, url: str) -> None: - """ - If the user is not a server admin, an error 403 is returned. - """ + """If the user is not a server admin, an error 403 is returned.""" self.register_user("user", "pass", admin=False) other_user_tok = self.login("user", "pass") @@ -70,9 +68,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: - """ - If parameters are invalid, an error is returned. - """ + """If parameters are invalid, an error is returned.""" # negative limit channel = self.make_request( @@ -135,9 +131,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_limit(self) -> None: - """ - Testing list of destinations with limit - """ + """Testing list of destinations with limit""" number_destinations = 20 self._create_destinations(number_destinations) @@ -155,9 +149,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_from(self) -> None: - """ - Testing list of destinations with a defined starting point (from) - """ + """Testing list of destinations with a defined starting point (from)""" number_destinations = 20 self._create_destinations(number_destinations) @@ -175,9 +167,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_limit_and_from(self) -> None: - """ - Testing list of destinations with a defined starting point and limit - """ + """Testing list of destinations with a defined starting point and limit""" number_destinations = 20 self._create_destinations(number_destinations) @@ -195,9 +185,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_next_token(self) -> None: - """ - Testing that `next_token` appears at the right place - """ + """Testing that `next_token` appears at the right place""" number_destinations = 20 self._create_destinations(number_destinations) @@ -256,9 +244,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) def test_list_all_destinations(self) -> None: - """ - List all destinations. - """ + """List all destinations.""" number_destinations = 5 self._create_destinations(number_destinations) @@ -277,9 +263,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._check_fields(channel.json_body["destinations"]) def test_order_by(self) -> None: - """ - Testing order list with parameter `order_by` - """ + """Testing order list with parameter `order_by`""" def _order_test( expected_destination_list: List[str], @@ -543,3 +527,271 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertIn("retry_interval", c) self.assertIn("failure_ts", c) self.assertIn("last_successful_stream_ordering", c) + + +class DestinationMembershipTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.dest = "sub0.example.com" + self.url = f"/_synapse/admin/v1/federation/destinations/{self.dest}/rooms" + + # Record that we successfully contacted a destination in the DB. + self.get_success( + self.store.set_destination_retry_timings(self.dest, None, 0, 0) + ) + + def test_requester_is_no_admin(self) -> None: + """If the user is not a server admin, an error 403 is returned.""" + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + "GET", + self.url, + access_token=other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self) -> None: + """If parameters are invalid, an error is returned.""" + + # negative limit + channel = self.make_request( + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid destination + channel = self.make_request( + "GET", + "/_synapse/admin/v1/federation/destinations/%s/rooms" % ("invalid",), + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_limit(self) -> None: + """Testing list of destinations with limit""" + + number_rooms = 5 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url + "?limit=3", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 3) + self.assertEqual(channel.json_body["next_token"], "3") + self._check_fields(channel.json_body["rooms"]) + + def test_from(self) -> None: + """Testing list of rooms with a defined starting point (from)""" + + number_rooms = 10 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 5) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["rooms"]) + + def test_limit_and_from(self) -> None: + """Testing list of rooms with a defined starting point and limit""" + + number_rooms = 10 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url + "?from=3&limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(channel.json_body["next_token"], "8") + self.assertEqual(len(channel.json_body["rooms"]), 5) + self._check_fields(channel.json_body["rooms"]) + + def test_order_direction(self) -> None: + """Testing order list with parameter `dir`""" + number_rooms = 4 + self._create_destination_rooms(number_rooms) + + # get list in forward direction + channel_asc = self.make_request( + "GET", + self.url + "?dir=f", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body) + self.assertEqual(channel_asc.json_body["total"], number_rooms) + self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"])) + self._check_fields(channel_asc.json_body["rooms"]) + + # get list in backward direction + channel_desc = self.make_request( + "GET", + self.url + "?dir=b", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body) + self.assertEqual(channel_desc.json_body["total"], number_rooms) + self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"])) + self._check_fields(channel_desc.json_body["rooms"]) + + # test that both lists have different directions + for i in range(0, number_rooms): + self.assertEqual( + channel_asc.json_body["rooms"][i]["room_id"], + channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"], + ) + + def test_next_token(self) -> None: + """Testing that `next_token` appears at the right place""" + + number_rooms = 5 + self._create_destination_rooms(number_rooms) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), number_rooms) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=6", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), number_rooms) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=4", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 4) + self.assertEqual(channel.json_body["next_token"], "4") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", + self.url + "?from=4", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(len(channel.json_body["rooms"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_destination_rooms(self) -> None: + """Testing that request the list of rooms is successfully.""" + number_rooms = 3 + self._create_destination_rooms(number_rooms) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_rooms) + self.assertEqual(number_rooms, len(channel.json_body["rooms"])) + self._check_fields(channel.json_body["rooms"]) + + def _create_destination_rooms(self, number_rooms: int) -> None: + """Create a number rooms for destination + + Args: + number_rooms: Number of rooms to be created + """ + for _ in range(0, number_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + self.get_success( + self.store.store_destination_rooms_entries((self.dest,), room_id, 1234) + ) + + def _check_fields(self, content: List[JsonDict]) -> None: + """Checks that the expected room attributes are present in content + + Args: + content: List that is checked for content + """ + for c in content: + self.assertIn("room_id", c) + self.assertIn("stream_ordering", c) -- cgit 1.5.1 From 95b3f952fa43e51feae166fa1678761c5e32d900 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 26 Jan 2022 12:02:54 +0000 Subject: Add a config flag to inhibit `M_USER_IN_USE` during registration (#11743) This is mostly motivated by the tchap use case, where usernames are automatically generated from the user's email address (in a way that allows figuring out the email address from the username). Therefore, it's an issue if we respond to requests on /register and /register/available with M_USER_IN_USE, because it can potentially leak email addresses (which include the user's real name and place of work). This commit adds a flag to inhibit the M_USER_IN_USE errors that are raised both by /register/available, and when providing a username early into the registration process. This error will still be raised if the user completes the registration process but the username conflicts. This is particularly useful when using modules (https://github.com/matrix-org/synapse/pull/11790 adds a module callback to set the username of users at registration) or SSO, since they can ensure the username is unique. More context is available in the PR that introduced this behaviour to synapse-dinsic: matrix-org/synapse-dinsic#48 - as well as the issue in the matrix-dinsic repo: matrix-org/matrix-dinsic#476 --- changelog.d/11743.feature | 1 + docs/sample_config.yaml | 10 ++++++++++ synapse/config/registration.py | 12 +++++++++++ synapse/handlers/register.py | 26 +++++++++++++----------- synapse/rest/client/register.py | 11 ++++++++++ tests/rest/client/test_register.py | 41 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 89 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11743.feature (limited to 'synapse/rest') diff --git a/changelog.d/11743.feature b/changelog.d/11743.feature new file mode 100644 index 0000000000..9809f48b96 --- /dev/null +++ b/changelog.d/11743.feature @@ -0,0 +1 @@ +Add a config flag to inhibit M_USER_IN_USE during registration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1b86d0295d..b38e6d6c88 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1428,6 +1428,16 @@ account_threepid_delegates: # #auto_join_rooms_for_guests: false +# Whether to inhibit errors raised when registering a new account if the user ID +# already exists. If turned on, that requests to /register/available will always +# show a user ID as available, and Synapse won't raise an error when starting +# a registration with a user ID that already exists. However, Synapse will still +# raise an error if the registration completes and the username conflicts. +# +# Defaults to false. +# +#inhibit_user_in_use_error: true + ## Metrics ### diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7a059c6dec..ea9b50fe97 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -190,6 +190,8 @@ class RegistrationConfig(Config): # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") + self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False) + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -446,6 +448,16 @@ class RegistrationConfig(Config): # Defaults to true. # #auto_join_rooms_for_guests: false + + # Whether to inhibit errors raised when registering a new account if the user ID + # already exists. If turned on, that requests to /register/available will always + # show a user ID as available, and Synapse won't raise an error when starting + # a registration with a user ID that already exists. However, Synapse will still + # raise an error if the registration completes and the username conflicts. + # + # Defaults to false. + # + #inhibit_user_in_use_error: true """ % locals() ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f08a516a75..a719d5eef3 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -132,6 +132,7 @@ class RegistrationHandler: localpart: str, guest_access_token: Optional[str] = None, assigned_user_id: Optional[str] = None, + inhibit_user_in_use_error: bool = False, ) -> None: if types.contains_invalid_mxid_characters(localpart): raise SynapseError( @@ -171,21 +172,22 @@ class RegistrationHandler: users = await self.store.get_users_by_id_case_insensitive(user_id) if users: - if not guest_access_token: + if not inhibit_user_in_use_error and not guest_access_token: raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - user_data = await self.auth.get_user_by_access_token(guest_access_token) - if ( - not user_data.is_guest - or UserID.from_string(user_data.user_id).localpart != localpart - ): - raise AuthError( - 403, - "Cannot register taken user ID without valid guest " - "credentials for that user.", - errcode=Codes.FORBIDDEN, - ) + if guest_access_token: + user_data = await self.auth.get_user_by_access_token(guest_access_token) + if ( + not user_data.is_guest + or UserID.from_string(user_data.user_id).localpart != localpart + ): + raise AuthError( + 403, + "Cannot register taken user ID without valid guest " + "credentials for that user.", + errcode=Codes.FORBIDDEN, + ) if guest_access_token is None: try: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 8b56c76aed..c59dae7c03 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -339,12 +339,19 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) + self.inhibit_user_in_use_error = ( + hs.config.registration.inhibit_user_in_use_error + ) + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) + if self.inhibit_user_in_use_error: + return 200, {"available": True} + ip = request.getClientIP() with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -422,6 +429,9 @@ class RegisterRestServlet(RestServlet): self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None ) + self._inhibit_user_in_use_error = ( + hs.config.registration.inhibit_user_in_use_error + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -564,6 +574,7 @@ class RegisterRestServlet(RestServlet): desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, + inhibit_user_in_use_error=self._inhibit_user_in_use_error, ) # Check if the user-interactive authentication flows are complete, if diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 6e7c0f11df..407dd32a73 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -726,6 +726,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, ) + @override_config( + { + "inhibit_user_in_use_error": True, + } + ) + def test_inhibit_user_in_use_error(self): + """Tests that the 'inhibit_user_in_use_error' configuration flag behaves + correctly. + """ + username = "arthur" + + # Manually register the user, so we know the test isn't passing because of a lack + # of clashing. + reg_handler = self.hs.get_registration_handler() + self.get_success(reg_handler.register_user(username)) + + # Check that /available correctly ignores the username provided despite the + # username being already registered. + channel = self.make_request("GET", "register/available?username=" + username) + self.assertEquals(200, channel.code, channel.result) + + # Test that when starting a UIA registration flow the request doesn't fail because + # of a conflicting username + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "foo"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Test that finishing the registration fails because of a conflicting username. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) + class AccountValidityTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 2897fb6b4fb8bdaea0e919233d5ccaf5dea12742 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Jan 2022 08:27:04 -0500 Subject: Improvements to bundling aggregations. (#11815) This is some odds and ends found during the review of #11791 and while continuing to work in this code: * Return attrs classes instead of dictionaries from some methods to improve type safety. * Call `get_bundled_aggregations` fewer times. * Adds a missing assertion in the tests. * Do not return empty bundled aggregations for an event (preferring to not include the bundle at all, as the docstring states). --- changelog.d/11815.misc | 1 + synapse/events/utils.py | 57 ++++++++++++++------- synapse/handlers/room.py | 77 +++++++++++++++-------------- synapse/handlers/search.py | 45 ++++++++--------- synapse/handlers/sync.py | 3 +- synapse/push/mailer.py | 2 +- synapse/rest/admin/rooms.py | 39 +++++++++------ synapse/rest/client/room.py | 39 +++++++++------ synapse/rest/client/sync.py | 3 +- synapse/storage/databases/main/relations.py | 61 ++++++++++++++--------- synapse/storage/databases/main/stream.py | 22 ++++++--- tests/rest/client/test_relations.py | 2 +- 12 files changed, 212 insertions(+), 139 deletions(-) create mode 100644 changelog.d/11815.misc (limited to 'synapse/rest') diff --git a/changelog.d/11815.misc b/changelog.d/11815.misc new file mode 100644 index 0000000000..83aa6d6eb0 --- /dev/null +++ b/changelog.d/11815.misc @@ -0,0 +1 @@ +Improve type safety of bundled aggregations code. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 918adeecf8..243696b357 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,7 +14,17 @@ # limitations under the License. import collections.abc import re -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, +) from frozendict import frozendict @@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze from . import EventBase +if TYPE_CHECKING: + from synapse.storage.databases.main.relations import BundledAggregations + + # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? JsonDict: """Serializes a single event. @@ -415,7 +429,7 @@ class EventClientSerializer: self, event: EventBase, time_now: int, - aggregations: JsonDict, + aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. @@ -427,13 +441,18 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ - # Make a copy in-case the object is cached. - aggregations = aggregations.copy() + serialized_aggregations = {} + + if aggregations.annotations: + serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations + + if aggregations.references: + serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references - if RelationTypes.REPLACE in aggregations: + if aggregations.replace: # If there is an edit replace the content, preserving existing # relations. - edit = aggregations[RelationTypes.REPLACE] + edit = aggregations.replace # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -451,24 +470,28 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - aggregations[RelationTypes.REPLACE] = { + serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, } # If this event is the start of a thread, include a summary of the replies. - if RelationTypes.THREAD in aggregations: - # Serialize the latest thread event. - latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] - - # Don't bundle aggregations as this could recurse forever. - aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=None - ) + if aggregations.thread: + serialized_aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": self.serialize_event( + aggregations.thread.latest_event, time_now, bundle_aggregations=None + ), + "count": aggregations.thread.count, + "current_user_participated": aggregations.thread.current_user_participated, + } # Include the bundled aggregations in the event. - serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + if serialized_aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + serialized_aggregations + ) def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f963078e59..1420d67729 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -30,6 +30,7 @@ from typing import ( Tuple, ) +import attr from typing_extensions import TypedDict from synapse.api.constants import ( @@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -90,6 +92,17 @@ id_server_scheme = "https://" FIVE_MINUTES_IN_MS = 5 * 60 * 1000 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventContext: + events_before: List[EventBase] + event: EventBase + events_after: List[EventBase] + state: List[EventBase] + aggregations: Dict[str, BundledAggregations] + start: str + end: str + + class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1119,7 +1132,7 @@ class RoomContextHandler: limit: int, event_filter: Optional[Filter], use_admin_priviledge: bool = False, - ) -> Optional[JsonDict]: + ) -> Optional[EventContext]: """Retrieves events, pagination tokens and state around a given event in a room. @@ -1167,38 +1180,28 @@ class RoomContextHandler: results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) + events_before = results.events_before + events_after = results.events_after if event_filter: - results["events_before"] = await event_filter.filter( - results["events_before"] - ) - results["events_after"] = await event_filter.filter(results["events_after"]) + events_before = await event_filter.filter(events_before) + events_after = await event_filter.filter(events_after) - results["events_before"] = await filter_evts(results["events_before"]) - results["events_after"] = await filter_evts(results["events_after"]) + events_before = await filter_evts(events_before) + events_after = await filter_evts(events_after) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. - results["event"] = filtered[0] + event = filtered[0] # Fetch the aggregations. aggregations = await self.store.get_bundled_aggregations( - [results["event"]], user.to_string() + itertools.chain(events_before, (event,), events_after), + user.to_string(), ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_before"], user.to_string() - ) - ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_after"], user.to_string() - ) - ) - results["aggregations"] = aggregations - if results["events_after"]: - last_event_id = results["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event_id @@ -1206,9 +1209,9 @@ class RoomContextHandler: state_filter = StateFilter.from_lazy_load_member_list( ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], + events_before, + (event,), + events_after, ) ) else: @@ -1226,21 +1229,23 @@ class RoomContextHandler: if event_filter: state_events = await event_filter.filter(state_events) - results["state"] = await filter_evts(state_events) - # We use a dummy token here as we only care about the room portion of # the token, which we replace. token = StreamToken.START - results["start"] = await token.copy_and_replace( - "room_key", results["start"] - ).to_string(self.store) - - results["end"] = await token.copy_and_replace( - "room_key", results["end"] - ).to_string(self.store) - - return results + return EventContext( + events_before=events_before, + event=event, + events_after=events_after, + state=await filter_evts(state_events), + aggregations=aggregations, + start=await token.copy_and_replace("room_key", results.start).to_string( + self.store + ), + end=await token.copy_and_replace("room_key", results.end).to_string( + self.store + ), + ) class TimestampLookupHandler: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0b153a6822..02bb5ae72f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -361,36 +361,37 @@ class SearchHandler: logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), - len(res["events_after"]), + len(res.events_before), + len(res.events_after), ) - res["events_before"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_before"] + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before ) - res["events_after"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_after"] + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after ) - res["start"] = await now_token.copy_and_replace( - "room_key", res["start"] - ).to_string(self.store) - - res["end"] = await now_token.copy_and_replace( - "room_key", res["end"] - ).to_string(self.store) + context = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace( + "room_key", res.end + ).to_string(self.store), + } if include_profile: senders = { ev.sender - for ev in itertools.chain( - res["events_before"], [event], res["events_after"] - ) + for ev in itertools.chain(events_before, [event], events_after) } - if res["events_after"]: - last_event_id = res["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event.event_id @@ -402,7 +403,7 @@ class SearchHandler: last_event_id, state_filter ) - res["profile_info"] = { + context["profile_info"] = { s.state_key: { "displayname": s.content.get("displayname", None), "avatar_url": s.content.get("avatar_url", None), @@ -411,7 +412,7 @@ class SearchHandler: if s.type == EventTypes.Member and s.state_key in senders } - contexts[event.event_id] = res + contexts[event.event_id] = context else: contexts = {} @@ -421,10 +422,10 @@ class SearchHandler: for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now + context["events_before"], time_now # type: ignore[arg-type] ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now + context["events_after"], time_now # type: ignore[arg-type] ) state_results = {} diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7e2a892b63..c72ed7c290 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -100,7 +101,7 @@ class TimelineBatch: limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. - bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None + bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index dadfc57413..3df8452eec 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -455,7 +455,7 @@ class Mailer: } the_events = await filter_events_for_client( - self.storage, user_id, results["events_before"] + self.storage, user_id, results.events_before ) the_events.append(notif_event) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index efe25fe7eb..5b706efbcf 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, @@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet): use_admin_priviledge=True, ) - if not results: + if not event_context: raise SynapseError( HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND ) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return HTTPStatus.OK, results diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90bb9142a0..90355e44b2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, limit, event_filter ) - if not results: + if not event_context: raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index d20ae1421e..f9615da525 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -48,6 +48,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace +from synapse.storage.databases.main.relations import BundledAggregations from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet): def serialize( events: Iterable[EventBase], - aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + aggregations: Optional[Dict[str, BundledAggregations]] = None, ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2cb5d06c13..a9a5dd5f03 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,17 +13,7 @@ # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast import attr from frozendict import frozendict @@ -43,6 +33,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -51,6 +42,30 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + latest_event: EventBase + count: int + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore): async def _get_bundled_aggregation_for_event( self, event: EventBase, user_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[BundledAggregations]: """Generate bundled aggregations for an event. Note that this does not use a cache, but depends on cached methods. @@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore): # The bundled aggregations to include, a mapping of relation type to a # type-specific value. Some types include the direct return type here # while others need more processing during serialization. - aggregations: Dict[str, Any] = {} + aggregations = BundledAggregations() annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations.annotations = annotations.to_dict() references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() + aggregations.references = references.to_dict() edit = None if event.type == EventTypes.Message: edit = await self.get_applicable_edit(event_id, room_id) if edit: - aggregations[RelationTypes.REPLACE] = edit + aggregations.replace = edit # If this event is the start of a thread, include a summary of the replies. if self._msc3440_enabled: @@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore): event_id, room_id, user_id ) if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - "latest_event": latest_thread_event, - "count": thread_count, - "current_user_participated": participated, - } + aggregations.thread = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + current_user_participated=participated, + ) # Store the bundled aggregations in the event metadata for later use. return aggregations @@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore): self, events: Iterable[EventBase], user_id: str, - ) -> Dict[str, Dict[str, Any]]: + ) -> Dict[str, BundledAggregations]: """Generate bundled aggregations for events. Args: @@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore): results = {} for event in events: event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result is not None: + if event_result: results[event.event_id] = event_result return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 319464b1fa..a898f847e7 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -81,6 +81,14 @@ class _EventDictReturn: stream_ordering: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventsAround: + events_before: List[EventBase] + events_after: List[EventBase] + start: RoomStreamToken + end: RoomStreamToken + + def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], @@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): before_limit: int, after_limit: int, event_filter: Optional[Filter] = None, - ) -> dict: + ) -> _EventsAround: """Retrieve events and pagination tokens around a given event in a room. """ @@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): list(results["after"]["event_ids"]), get_prev_content=True ) - return { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } + return _EventsAround( + events_before=events_before, + events_after=events_after, + start=results["before"]["token"], + end=results["after"]["token"], + ) def _get_events_around_txn( self, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c9b220e73d..96ae7790bb 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - self._find_event_in_chunk(room_timeline["events"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included -- cgit 1.5.1 From 2d3bd9aa670eedd299cc03093459929adec41918 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 26 Jan 2022 14:21:13 +0000 Subject: Add a module callback to set username at registration (#11790) This is in the context of mainlining the Tchap fork of Synapse. Currently in Tchap usernames are derived from the user's email address (extracted from the UIA results, more specifically the m.login.email.identity step). This change also exports the check_username method from the registration handler as part of the module API, so that a module can check if the username it's trying to generate is correct and doesn't conflict with an existing one, and fallback gracefully if not. Co-authored-by: David Robertson --- changelog.d/11790.feature | 1 + docs/modules/password_auth_provider_callbacks.md | 62 +++++++++++++++++++ synapse/handlers/auth.py | 58 +++++++++++++++++ synapse/module_api/__init__.py | 22 +++++++ synapse/rest/client/register.py | 12 +++- tests/handlers/test_password_providers.py | 79 +++++++++++++++++++++++- 6 files changed, 231 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11790.feature (limited to 'synapse/rest') diff --git a/changelog.d/11790.feature b/changelog.d/11790.feature new file mode 100644 index 0000000000..4a5cc8ec37 --- /dev/null +++ b/changelog.d/11790.feature @@ -0,0 +1 @@ +Add a module callback to set username at registration. diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index e53abf6409..ec8324d292 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -105,6 +105,68 @@ device ID), and the (now deactivated) access token. If multiple modules implement this callback, Synapse runs them all in order. +### `get_username_for_registration` + +_First introduced in Synapse v1.52.0_ + +```python +async def get_username_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a username to set for the user +being registered by returning it as a string, or `None` if it doesn't wish to force a +username for this user. If a username is returned, it will be used as the local part of a +user's full Matrix ID (e.g. it's `alice` in `@alice:example.com`). + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. + +The first dictionary contains the results of the [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +flow followed by the user. Its keys are the identifiers of every step involved in the flow, +associated with either a boolean value indicating whether the step was correctly completed, +or additional information (e.g. email address, phone number...). A list of most existing +identifiers can be found in the [Matrix specification](https://spec.matrix.org/v1.1/client-server-api/#authentication-types). +Here's an example featuring all currently supported keys: + +```python +{ + "m.login.dummy": True, # Dummy authentication + "m.login.terms": True, # User has accepted the terms of service for the homeserver + "m.login.recaptcha": True, # User has completed the recaptcha challenge + "m.login.email.identity": { # User has provided and verified an email address + "medium": "email", + "address": "alice@example.com", + "validated_at": 1642701357084, + }, + "m.login.msisdn": { # User has provided and verified a phone number + "medium": "msisdn", + "address": "33123456789", + "validated_at": 1642701357084, + }, + "org.matrix.msc3231.login.registration_token": "sometoken", # User has registered through the flow described in MSC3231 +} +``` + +The second dictionary contains the parameters provided by the user's client in the request +to `/_matrix/client/v3/register`. See the [Matrix specification](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3register) +for a complete list of these parameters. + +If the module cannot, or does not wish to, generate a username for this user, it must +return `None`. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback return `None`, +the username provided by the user is used, if any (otherwise one is automatically +generated). + + ## Example The example module below implements authentication checkers for two different login types: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bd1a322563..e32c93e234 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[ Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] ], ] +GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] class PasswordAuthProvider: @@ -2072,6 +2076,9 @@ class PasswordAuthProvider: # lists of callbacks self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + self.get_username_for_registration_callbacks: List[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = [] # Mapping from login type to login parameters self._supported_login_types: Dict[str, Iterable[str]] = {} @@ -2086,6 +2093,9 @@ class PasswordAuthProvider: auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, + get_username_for_registration: Optional[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2130,6 +2140,11 @@ class PasswordAuthProvider: # Add the new method to the list of auth_checker_callbacks for this login type self.auth_checker_callbacks.setdefault(login_type, []).append(callback) + if get_username_for_registration is not None: + self.get_username_for_registration_callbacks.append( + get_username_for_registration, + ) + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -2285,3 +2300,46 @@ class PasswordAuthProvider: except Exception as e: logger.warning("Failed to run module API callback %s: %s", callback, e) continue + + async def get_username_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the username to use when registering the user, using the credentials + and parameters provided during the UIA flow. + + Stops at the first callback that returns a string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + The localpart to use when registering this user, or None if no module + returned a localpart. + """ + for callback in self.get_username_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_username_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_username_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 662e60bc33..788b2e47d5 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -71,6 +71,7 @@ from synapse.handlers.account_validity import ( from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_USERNAME_FOR_REGISTRATION_CALLBACK, ON_LOGGED_OUT_CALLBACK, AuthHandler, ) @@ -177,6 +178,7 @@ class ModuleApi: self._presence_stream = hs.get_event_sources().sources.presence self._state = hs.get_state_handler() self._clock: Clock = hs.get_clock() + self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self.custom_template_dir = hs.config.server.custom_template_directory @@ -310,6 +312,9 @@ class ModuleApi: auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, + get_username_for_registration: Optional[ + GET_USERNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -319,6 +324,7 @@ class ModuleApi: check_3pid_auth=check_3pid_auth, on_logged_out=on_logged_out, auth_checkers=auth_checkers, + get_username_for_registration=get_username_for_registration, ) def register_background_update_controller_callbacks( @@ -1202,6 +1208,22 @@ class ModuleApi: """ return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs) + async def check_username(self, username: str) -> None: + """Checks if the provided username uses the grammar defined in the Matrix + specification, and is already being used by an existing user. + + Added in Synapse v1.52.0. + + Args: + username: The username to check. This is the local part of the user's full + Matrix user ID, i.e. it's "alice" if the full user ID is "@alice:foo.com". + + Raises: + SynapseError with the errcode "M_USER_IN_USE" if the username is already in + use. + """ + await self._registration_handler.check_username(username) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index c59dae7c03..e3492f9f93 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -425,6 +425,7 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self.password_auth_provider = hs.get_password_auth_provider() self._registration_enabled = self.hs.config.registration.enable_registration self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None @@ -638,7 +639,16 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = params.get("username", None) + desired_username = await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) + ) + + if desired_username is None: + desired_username = params.get("username", None) + guest_access_token = params.get("guest_access_token", None) if desired_username is not None: diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 2add72b28a..94809cb8be 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,10 +20,11 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.api.constants import LoginType from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login, logout -from synapse.types import JsonDict +from synapse.rest.client import devices, login, logout, register +from synapse.types import JsonDict, UserID from tests import unittest from tests.server import FakeChannel @@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): login.register_servlets, devices.register_servlets, logout.register_servlets, + register.register_servlets, ] def setUp(self): @@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): on_logged_out.assert_called_once() self.assertTrue(self.called) + def test_username(self): + """Tests that the get_username_for_registration callback can define the username + of a user when registering. + """ + self._setup_get_username_for_registration() + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + mxid = channel.json_body["user_id"] + self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + def test_username_uia(self): + """Tests that the get_username_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_username_for_registration() + + # Initiate the UIA flow. + username = "rin" + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + mxid = channel.json_body["user_id"] + self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + + def _setup_get_username_for_registration(self) -> Mock: + """Registers a get_username_for_registration callback that appends "-foo" to the + username the client is trying to register. + """ + + async def get_username_for_registration(uia_results, params): + self.assertIn(LoginType.DUMMY, uia_results) + username = params["username"] + return username + "-foo" + + m = Mock(side_effect=get_username_for_registration) + + password_auth_provider = self.hs.get_password_auth_provider() + password_auth_provider.get_username_for_registration_callbacks.append(m) + + return m + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1