From 5ea29d7f850b6d2acbbfaf2e81bc5f0625411320 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 09:39:02 -0400 Subject: Convert more of the media code to async/await (#7873) --- synapse/rest/media/v1/_base.py | 15 +++++---- synapse/rest/media/v1/media_storage.py | 60 +++++++++++++++++++--------------- 2 files changed, 42 insertions(+), 33 deletions(-) (limited to 'synapse/rest/media/v1') diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 595849f9d5..9a847130c0 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -18,7 +18,6 @@ import logging import os import urllib -from twisted.internet import defer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -77,8 +76,9 @@ def respond_404(request): ) -@defer.inlineCallbacks -def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None): +async def respond_with_file( + request, media_type, file_path, file_size=None, upload_name=None +): logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) finish_request(request) else: @@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x): return True -@defer.inlineCallbacks -def respond_with_responder(request, responder, media_type, file_size, upload_name=None): +async def respond_with_responder( + request, responder, media_type, file_size, upload_name=None +): """Responds to the request with given responder. If responder is None then returns 404. @@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam add_file_headers(request, media_type, file_size, upload_name) try: with responder: - yield responder.write_to_consumer(request) + await responder.write_to_consumer(request) except Exception as e: # The majority of the time this will be due to the client having gone # away. Unfortunately, Twisted simply throws a generic exception at us diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 79cb0dddbe..66bc1c3360 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -14,17 +14,18 @@ # limitations under the License. import contextlib +import inspect import logging import os import shutil +from typing import Optional -from twisted.internet import defer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer -from ._base import Responder +from ._base import FileInfo, Responder logger = logging.getLogger(__name__) @@ -46,25 +47,24 @@ class MediaStorage(object): self.filepaths = filepaths self.storage_providers = storage_providers - @defer.inlineCallbacks - def store_file(self, source, file_info): + async def store_file(self, source, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers Args: source: A file like object that should be written - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Returns: - Deferred[str]: the file path written to in the primary media store + the file path written to in the primary media store """ with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), _write_file_synchronously, source, f ) - yield finish_cb() + await finish_cb() return fname @@ -75,7 +75,7 @@ class MediaStorage(object): Actually yields a 3-tuple (file, fname, finish_cb), where file is a file like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns a Deferred. + on disk, and finish_cb is a function that returns an awaitable. fname can be used to read the contents from after upload, e.g. to generate thumbnails. @@ -91,7 +91,7 @@ class MediaStorage(object): with media_storage.store_into_file(info) as (f, fname, finish_cb): # .. write into f ... - yield finish_cb() + await finish_cb() """ path = self._file_info_to_path(file_info) @@ -103,10 +103,13 @@ class MediaStorage(object): finished_called = [False] - @defer.inlineCallbacks - def finish(): + async def finish(): for provider in self.storage_providers: - yield provider.store_file(path, file_info) + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + result = provider.store_file(path, file_info) + if inspect.isawaitable(result): + await result finished_called[0] = True @@ -123,17 +126,15 @@ class MediaStorage(object): if not finished_called: raise Exception("Finished callback not called") - @defer.inlineCallbacks - def fetch_media(self, file_info): + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache and configured storage providers. Args: - file_info (FileInfo) + file_info Returns: - Deferred[Responder|None]: Returns a Responder if the file was found, - otherwise None. + Returns a Responder if the file was found, otherwise None. """ path = self._file_info_to_path(file_info) @@ -142,23 +143,26 @@ class MediaStorage(object): return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = provider.fetch(path, file_info) + # Fetch is supposed to return an Awaitable, but guard against + # improper implementations. + if inspect.isawaitable(res): + res = await res if res: logger.debug("Streaming %s from %s", path, provider) return res return None - @defer.inlineCallbacks - def ensure_media_is_in_local_cache(self, file_info): + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: """Ensures that the given file is in the local cache. Attempts to download it from storage providers if it isn't. Args: - file_info (FileInfo) + file_info Returns: - Deferred[str]: Full path to local file + Full path to local file """ path = self._file_info_to_path(file_info) local_path = os.path.join(self.local_media_directory, path) @@ -170,14 +174,18 @@ class MediaStorage(object): os.makedirs(dirname) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = provider.fetch(path, file_info) + # Fetch is supposed to return an Awaitable, but guard against + # improper implementations. + if inspect.isawaitable(res): + res = await res if res: with res: consumer = BackgroundFileConsumer( open(local_path, "wb"), self.hs.get_reactor() ) - yield res.write_to_consumer(consumer) - yield consumer.wait() + await res.write_to_consumer(consumer) + await consumer.wait() return local_path raise Exception("file could not be found") -- cgit 1.5.1 From 3fc8fdd150e2471d6e96b842e364d9421066f4ba Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 07:50:44 -0400 Subject: Support oEmbed for media previews. (#7920) Fixes previews of Twitter URLs by using their oEmbed endpoint to grab content. --- changelog.d/7920.feature | 1 + synapse/rest/media/v1/preview_url_resource.py | 265 +++++++++++++++++++++----- tests/rest/media/v1/test_url_preview.py | 142 +++++++++++++- 3 files changed, 355 insertions(+), 53 deletions(-) create mode 100644 changelog.d/7920.feature (limited to 'synapse/rest/media/v1') diff --git a/changelog.d/7920.feature b/changelog.d/7920.feature new file mode 100644 index 0000000000..4093f5d329 --- /dev/null +++ b/changelog.d/7920.feature @@ -0,0 +1 @@ +Support oEmbed for media previews. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e52c86c798..13d1a6d2ed 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -26,6 +26,7 @@ import traceback from typing import Dict, Optional from urllib import parse as urlparse +import attr from canonicaljson import json from twisted.internet import defer @@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) OG_TAG_NAME_MAXLEN = 50 OG_TAG_VALUE_MAXLEN = 1000 +ONE_HOUR = 60 * 60 * 1000 + +# A map of globs to API endpoints. +_oembed_globs = { + # Twitter. + "https://publish.twitter.com/oembed": [ + "https://twitter.com/*/status/*", + "https://*.twitter.com/*/status/*", + "https://twitter.com/*/moments/*", + "https://*.twitter.com/*/moments/*", + # Include the HTTP versions too. + "http://twitter.com/*/status/*", + "http://*.twitter.com/*/status/*", + "http://twitter.com/*/moments/*", + "http://*.twitter.com/*/moments/*", + ], +} +# Convert the globs to regular expressions. +_oembed_patterns = {} +for endpoint, globs in _oembed_globs.items(): + for glob in globs: + # Convert the glob into a sane regular expression to match against. The + # rules followed will be slightly different for the domain portion vs. + # the rest. + # + # 1. The scheme must be one of HTTP / HTTPS (and have no globs). + # 2. The domain can have globs, but we limit it to characters that can + # reasonably be a domain part. + # TODO: This does not attempt to handle Unicode domain names. + # 3. Other parts allow a glob to be any one, or more, characters. + results = urlparse.urlparse(glob) + + # Ensure the scheme does not have wildcards (and is a sane scheme). + if results.scheme not in {"http", "https"}: + raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,)) + + pattern = urlparse.urlunparse( + [ + results.scheme, + re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"), + ] + + [re.escape(part).replace("\\*", ".+") for part in results[2:]] + ) + _oembed_patterns[re.compile(pattern)] = endpoint + + +@attr.s +class OEmbedResult: + # Either HTML content or URL must be provided. + html = attr.ib(type=Optional[str]) + url = attr.ib(type=Optional[str]) + title = attr.ib(type=Optional[str]) + # Number of seconds to cache the content. + cache_age = attr.ib(type=int) + + +class OEmbedError(Exception): + """An error occurred processing the oEmbed object.""" + class PreviewUrlResource(DirectServeJsonResource): isLeaf = True @@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource): cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour - expiry_ms=60 * 60 * 1000, + expiry_ms=ONE_HOUR, ) if self._worker_run_media_background_jobs: @@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") + def _get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Params: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in _oembed_patterns.items(): + if url_pattern.fullmatch(url): + return endpoint + + # No match. + return None + + async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: + """ + Request content from an oEmbed endpoint. + + Params: + endpoint: The oEmbed API endpoint. + url: The URL to pass to the API. + + Returns: + An object representing the metadata returned. + + Raises: + OEmbedError if fetching or parsing of the oEmbed information fails. + """ + try: + logger.debug("Trying to get oEmbed content for url '%s'", url) + result = await self.client.get_json( + endpoint, + # TODO Specify max height / width. + # Note that only the JSON format is supported. + args={"url": url}, + ) + + # Ensure there's a version of 1.0. + if result.get("version") != "1.0": + raise OEmbedError("Invalid version: %s" % (result.get("version"),)) + + oembed_type = result.get("type") + + # Ensure the cache age is None or an int. + cache_age = result.get("cache_age") + if cache_age: + cache_age = int(cache_age) + + oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) + + # HTML content. + if oembed_type == "rich": + oembed_result.html = result.get("html") + return oembed_result + + if oembed_type == "photo": + oembed_result.url = result.get("url") + return oembed_result + + # TODO Handle link and video types. + + if "thumbnail_url" in result: + oembed_result.url = result.get("thumbnail_url") + return oembed_result + + raise OEmbedError("Incompatible oEmbed information.") + + except OEmbedError as e: + # Trap OEmbedErrors first so we can directly re-raise them. + logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) + raise + + except Exception as e: + # Trap any exception and let the code follow as usual. + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) + raise OEmbedError() from e + async def _download_url(self, url, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a @@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) - with self.media_storage.store_into_file(file_info) as (f, fname, finish): + # If this URL can be accessed via oEmbed, use that instead. + url_to_download = url + oembed_url = self._get_oembed_url(url) + if oembed_url: + # The result might be a new URL to download, or it might be HTML content. try: - logger.debug("Trying to get preview for url '%s'", url) - length, headers, uri, code = await self.client.get_file( - url, - output_stream=f, - max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, - ) - except SynapseError: - # Pass SynapseErrors through directly, so that the servlet - # handler will return a SynapseError to the client instead of - # blank data or a 500. - raise - except DNSLookupError: - # DNS lookup returned no results - # Note: This will also be the case if one of the resolved IP - # addresses is blacklisted - raise SynapseError( - 502, - "DNS resolution failure during URL preview generation", - Codes.UNKNOWN, - ) - except Exception as e: - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading %s: %r", url, e) + oembed_result = await self._get_oembed_content(oembed_url, url) + if oembed_result.url: + url_to_download = oembed_result.url + elif oembed_result.html: + url_to_download = None + except OEmbedError: + # If an error occurs, try doing a normal preview. + pass - raise SynapseError( - 500, - "Failed to download content: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - await finish() + if url_to_download: + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: + logger.debug("Trying to get preview for url '%s'", url_to_download) + length, headers, uri, code = await self.client.get_file( + url_to_download, + output_stream=f, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, + ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, + ) + except Exception as e: + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading %s: %r", url_to_download, e) + + raise SynapseError( + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + await finish() + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + + download_name = get_filename_from_headers(headers) + + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + expires = ONE_HOUR + etag = headers["ETag"][0] if "ETag" in headers else None + else: + html_bytes = oembed_result.html.encode("utf-8") # type: ignore + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + f.write(html_bytes) + await finish() + + media_type = "text/html" + download_name = oembed_result.title + length = len(html_bytes) + # If a specific cache age was not given, assume 1 hour. + expires = oembed_result.cache_age or ONE_HOUR + uri = oembed_url + code = 200 + etag = None try: - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") - else: - media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() - download_name = get_filename_from_headers(headers) - await self.store.store_local_media( media_id=file_id, media_type=media_type, - time_now_ms=self.clock.time_msec(), + time_now_ms=time_now_ms, upload_name=download_name, media_length=length, user_id=user, @@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource): "filename": fname, "uri": uri, "response_code": code, - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - "expires": 60 * 60 * 1000, - "etag": headers["ETag"][0] if "ETag" in headers else None, + "expires": expires, + "etag": etag, } def _start_expire_url_cache_data(self): @@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource): # These may be cached for a bit on the client (i.e., they # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. - expire_before = now - 2 * 24 * 60 * 60 * 1000 + expire_before = now - 2 * 24 * ONE_HOUR media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 2826211f32..74765a582b 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -12,8 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import os +import re + +from mock import patch import attr @@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() def test_cache_returns_correct_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://matrix.org", shorthand=False @@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def test_non_ascii_preview_httpequiv(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"" @@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_non_ascii_preview_content_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"" @@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_overlong_title(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"" @@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ IP addresses can be previewed directly. """ - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False @@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Hardcode the URL resolving to the IP we want. self.lookups["example.com"] = [ (IPv4Address, "1.1.1.2"), - (IPv4Address, "8.8.8.8"), + (IPv4Address, "10.1.2.3"), ] request, channel = self.make_request( @@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Accept-Language header is sent to the remote server """ - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] # Build and make a request to the server request, channel = self.make_request( @@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase): ), server.data, ) + + def test_oembed_photo(self): + """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "photo", + "url": "http://cdn.twitter.com/matrixdotorg", + } + oembed_content = json.dumps(result).encode("utf-8") + + end_content = ( + b"" + b"Some Title" + b'' + b"" + ) + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(oembed_content),) + + oembed_content + ) + + self.pump() + + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + ) + + def test_oembed_rich(self): + """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "rich", + "html": "
Content Preview
", + } + end_content = json.dumps(result).encode("utf-8") + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + {"og:title": None, "og:description": "Content Preview"}, + ) -- cgit 1.5.1