summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11767.bugfix1
-rw-r--r--synapse/rest/media/v1/preview_html.py31
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py224
-rw-r--r--tests/rest/media/v1/test_html_preview.py (renamed from tests/test_preview.py)34
-rw-r--r--tests/rest/media/v1/test_url_preview.py81
-rw-r--r--tests/server.py2
6 files changed, 299 insertions, 74 deletions
diff --git a/changelog.d/11767.bugfix b/changelog.d/11767.bugfix
new file mode 100644
index 0000000000..3e344747f4
--- /dev/null
+++ b/changelog.d/11767.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when previewing Reddit URLs which do not contain an image.
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 30b067dd42..872a9e72e8 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -321,14 +321,33 @@ def _iterate_over_text(
 
 
 def rebase_url(url: str, base: str) -> str:
-    base_parts = list(urlparse.urlparse(base))
+    """
+    Resolves a potentially relative `url` against an absolute `base` URL.
+
+    For example:
+
+        >>> rebase_url("subpage", "https://example.com/foo/")
+        'https://example.com/foo/subpage'
+        >>> rebase_url("sibling", "https://example.com/foo")
+        'https://example.com/sibling'
+        >>> rebase_url("/bar", "https://example.com/foo/")
+        'https://example.com/bar'
+        >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
+        'https://alice.com/a'
+    """
+    base_parts = urlparse.urlparse(base)
+    # Convert the parsed URL to a list for (potential) modification.
     url_parts = list(urlparse.urlparse(url))
-    if not url_parts[0]:  # fix up schema
-        url_parts[0] = base_parts[0] or "http"
-    if not url_parts[1]:  # fix up hostname
-        url_parts[1] = base_parts[1]
+    # Add a scheme, if one does not exist.
+    if not url_parts[0]:
+        url_parts[0] = base_parts.scheme or "http"
+    # Fix up the hostname, if this is not a data URL.
+    if url_parts[0] != "data" and not url_parts[1]:
+        url_parts[1] = base_parts.netloc
+        # If the path does not start with a /, nest it under the base path's last
+        # directory.
         if not url_parts[2].startswith("/"):
-            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
     return urlparse.urlunparse(url_parts)
 
 
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e8881bc870..efd84ced8f 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -21,8 +21,9 @@ import re
 import shutil
 import sys
 import traceback
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
 from urllib import parse as urlparse
+from urllib.request import urlopen
 
 import attr
 
@@ -71,6 +72,17 @@ IMAGE_CACHE_EXPIRY_MS = 2 * ONE_DAY
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
+class DownloadResult:
+    length: int
+    uri: str
+    response_code: int
+    media_type: str
+    download_name: Optional[str]
+    expires: int
+    etag: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class MediaInfo:
     """
     Information parsed from downloading media being previewed.
@@ -256,7 +268,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         if oembed_url:
             url_to_download = oembed_url
 
-        media_info = await self._download_url(url_to_download, user)
+        media_info = await self._handle_url(url_to_download, user)
 
         logger.debug("got media_info of '%s'", media_info)
 
@@ -297,7 +309,9 @@ class PreviewUrlResource(DirectServeJsonResource):
                 oembed_url = self._oembed.autodiscover_from_html(tree)
                 og_from_oembed: JsonDict = {}
                 if oembed_url:
-                    oembed_info = await self._download_url(oembed_url, user)
+                    oembed_info = await self._handle_url(
+                        oembed_url, user, allow_data_urls=True
+                    )
                     (
                         og_from_oembed,
                         author_name,
@@ -367,7 +381,135 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         return jsonog.encode("utf8")
 
-    async def _download_url(self, url: str, user: UserID) -> MediaInfo:
+    async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResult:
+        """
+        Fetches a remote URL and parses the headers.
+
+        Args:
+             url: The URL to fetch.
+             output_stream: The stream to write the content to.
+
+        Returns:
+            A tuple of:
+                Media length, URL downloaded, the HTTP response code,
+                the media type, the downloaded file name, the number of
+                milliseconds the result is valid for, the etag header.
+        """
+
+        try:
+            logger.debug("Trying to get preview for url '%s'", url)
+            length, headers, uri, code = await self.client.get_file(
+                url,
+                output_stream=output_stream,
+                max_size=self.max_spider_size,
+                headers={"Accept-Language": self.url_preview_accept_language},
+            )
+        except SynapseError:
+            # Pass SynapseErrors through directly, so that the servlet
+            # handler will return a SynapseError to the client instead of
+            # blank data or a 500.
+            raise
+        except DNSLookupError:
+            # DNS lookup returned no results
+            # Note: This will also be the case if one of the resolved IP
+            # addresses is blacklisted
+            raise SynapseError(
+                502,
+                "DNS resolution failure during URL preview generation",
+                Codes.UNKNOWN,
+            )
+        except Exception as e:
+            # FIXME: pass through 404s and other error messages nicely
+            logger.warning("Error downloading %s: %r", url, e)
+
+            raise SynapseError(
+                500,
+                "Failed to download content: %s"
+                % (traceback.format_exception_only(sys.exc_info()[0], e),),
+                Codes.UNKNOWN,
+            )
+
+        if b"Content-Type" in headers:
+            media_type = headers[b"Content-Type"][0].decode("ascii")
+        else:
+            media_type = "application/octet-stream"
+
+        download_name = get_filename_from_headers(headers)
+
+        # FIXME: we should calculate a proper expiration based on the
+        # Cache-Control and Expire headers.  But for now, assume 1 hour.
+        expires = ONE_HOUR
+        etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+
+        return DownloadResult(
+            length, uri, code, media_type, download_name, expires, etag
+        )
+
+    async def _parse_data_url(
+        self, url: str, output_stream: BinaryIO
+    ) -> DownloadResult:
+        """
+        Parses a data: URL.
+
+        Args:
+             url: The URL to parse.
+             output_stream: The stream to write the content to.
+
+        Returns:
+            A tuple of:
+                Media length, URL downloaded, the HTTP response code,
+                the media type, the downloaded file name, the number of
+                milliseconds the result is valid for, the etag header.
+        """
+
+        try:
+            logger.debug("Trying to parse data url '%s'", url)
+            with urlopen(url) as url_info:
+                # TODO Can this be more efficient.
+                output_stream.write(url_info.read())
+        except Exception as e:
+            logger.warning("Error parsing data: URL %s: %r", url, e)
+
+            raise SynapseError(
+                500,
+                "Failed to parse data URL: %s"
+                % (traceback.format_exception_only(sys.exc_info()[0], e),),
+                Codes.UNKNOWN,
+            )
+
+        return DownloadResult(
+            # Read back the length that has been written.
+            length=output_stream.tell(),
+            uri=url,
+            # If it was parsed, consider this a 200 OK.
+            response_code=200,
+            # urlopen shoves the media-type from the data URL into the content type
+            # header object.
+            media_type=url_info.headers.get_content_type(),
+            # Some features are not supported by data: URLs.
+            download_name=None,
+            expires=ONE_HOUR,
+            etag=None,
+        )
+
+    async def _handle_url(
+        self, url: str, user: UserID, allow_data_urls: bool = False
+    ) -> MediaInfo:
+        """
+        Fetches content from a URL and parses the result to generate a MediaInfo.
+
+        It uses the media storage provider to persist the fetched content and
+        stores the mapping into the database.
+
+        Args:
+             url: The URL to fetch.
+             user: The user who ahs requested this URL.
+             allow_data_urls: True if data URLs should be allowed.
+
+        Returns:
+            A MediaInfo object describing the fetched content.
+        """
+
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -377,61 +519,27 @@ class PreviewUrlResource(DirectServeJsonResource):
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
         with self.media_storage.store_into_file(file_info) as (f, fname, finish):
-            try:
-                logger.debug("Trying to get preview for url '%s'", url)
-                length, headers, uri, code = await self.client.get_file(
-                    url,
-                    output_stream=f,
-                    max_size=self.max_spider_size,
-                    headers={"Accept-Language": self.url_preview_accept_language},
-                )
-            except SynapseError:
-                # Pass SynapseErrors through directly, so that the servlet
-                # handler will return a SynapseError to the client instead of
-                # blank data or a 500.
-                raise
-            except DNSLookupError:
-                # DNS lookup returned no results
-                # Note: This will also be the case if one of the resolved IP
-                # addresses is blacklisted
-                raise SynapseError(
-                    502,
-                    "DNS resolution failure during URL preview generation",
-                    Codes.UNKNOWN,
-                )
-            except Exception as e:
-                # FIXME: pass through 404s and other error messages nicely
-                logger.warning("Error downloading %s: %r", url, e)
-
-                raise SynapseError(
-                    500,
-                    "Failed to download content: %s"
-                    % (traceback.format_exception_only(sys.exc_info()[0], e),),
-                    Codes.UNKNOWN,
-                )
-            await finish()
+            if url.startswith("data:"):
+                if not allow_data_urls:
+                    raise SynapseError(
+                        500, "Previewing of data: URLs is forbidden", Codes.UNKNOWN
+                    )
 
-            if b"Content-Type" in headers:
-                media_type = headers[b"Content-Type"][0].decode("ascii")
+                download_result = await self._parse_data_url(url, f)
             else:
-                media_type = "application/octet-stream"
+                download_result = await self._download_url(url, f)
 
-            download_name = get_filename_from_headers(headers)
-
-            # FIXME: we should calculate a proper expiration based on the
-            # Cache-Control and Expire headers.  But for now, assume 1 hour.
-            expires = ONE_HOUR
-            etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+            await finish()
 
         try:
             time_now_ms = self.clock.time_msec()
 
             await self.store.store_local_media(
                 media_id=file_id,
-                media_type=media_type,
+                media_type=download_result.media_type,
                 time_now_ms=time_now_ms,
-                upload_name=download_name,
-                media_length=length,
+                upload_name=download_result.download_name,
+                media_length=download_result.length,
                 user_id=user,
                 url_cache=url,
             )
@@ -444,16 +552,16 @@ class PreviewUrlResource(DirectServeJsonResource):
             raise
 
         return MediaInfo(
-            media_type=media_type,
-            media_length=length,
-            download_name=download_name,
+            media_type=download_result.media_type,
+            media_length=download_result.length,
+            download_name=download_result.download_name,
             created_ts_ms=time_now_ms,
             filesystem_id=file_id,
             filename=fname,
-            uri=uri,
-            response_code=code,
-            expires=expires,
-            etag=etag,
+            uri=download_result.uri,
+            response_code=download_result.response_code,
+            expires=download_result.expires,
+            etag=download_result.etag,
         )
 
     async def _precache_image_url(
@@ -474,8 +582,8 @@ class PreviewUrlResource(DirectServeJsonResource):
         # FIXME: it might be cleaner to use the same flow as the main /preview_url
         # request itself and benefit from the same caching etc.  But for now we
         # just rely on the caching on the master request to speed things up.
-        image_info = await self._download_url(
-            rebase_url(og["og:image"], media_info.uri), user
+        image_info = await self._handle_url(
+            rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
         )
 
         if _is_media(image_info.media_type):
diff --git a/tests/test_preview.py b/tests/rest/media/v1/test_html_preview.py
index 46e02f483f..a4b57e3d1f 100644
--- a/tests/test_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,10 +16,11 @@ from synapse.rest.media.v1.preview_html import (
     _get_html_media_encodings,
     decode_body,
     parse_html_to_open_graph,
+    rebase_url,
     summarize_paragraphs,
 )
 
-from . import unittest
+from tests import unittest
 
 try:
     import lxml
@@ -447,3 +448,34 @@ class MediaEncodingTestCase(unittest.TestCase):
             'text/html; charset="invalid"',
         )
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
+
+
+class RebaseUrlTestCase(unittest.TestCase):
+    def test_relative(self):
+        """Relative URLs should be resolved based on the context of the base URL."""
+        self.assertEqual(
+            rebase_url("subpage", "https://example.com/foo/"),
+            "https://example.com/foo/subpage",
+        )
+        self.assertEqual(
+            rebase_url("sibling", "https://example.com/foo"),
+            "https://example.com/sibling",
+        )
+        self.assertEqual(
+            rebase_url("/bar", "https://example.com/foo/"),
+            "https://example.com/bar",
+        )
+
+    def test_absolute(self):
+        """Absolute URLs should not be modified."""
+        self.assertEqual(
+            rebase_url("https://alice.com/a/", "https://example.com/foo/"),
+            "https://alice.com/a/",
+        )
+
+    def test_data(self):
+        """Data URLs should not be modified."""
+        self.assertEqual(
+            rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
+            "data:,Hello%2C%20World%21",
+        )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 16e904f15b..53f6186213 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,9 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import json
 import os
 import re
+from urllib.parse import urlencode
 
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
@@ -23,6 +25,7 @@ from twisted.test.proto_helpers import AccumulatingProtocol
 
 from synapse.config.oembed import OEmbedEndpointConfig
 from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.types import JsonDict
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 from tests import unittest
@@ -142,6 +145,14 @@ class URLPreviewTests(unittest.HomeserverTestCase):
     def create_test_resource(self):
         return self.hs.get_media_repository_resource()
 
+    def _assert_small_png(self, json_body: JsonDict) -> None:
+        """Assert properties from the SMALL_PNG test image."""
+        self.assertTrue(json_body["og:image"].startswith("mxc://"))
+        self.assertEqual(json_body["og:image:height"], 1)
+        self.assertEqual(json_body["og:image:width"], 1)
+        self.assertEqual(json_body["og:image:type"], "image/png")
+        self.assertEqual(json_body["matrix:image:size"], 67)
+
     def test_cache_returns_correct_type(self):
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
@@ -569,6 +580,66 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             server.data,
         )
 
+    def test_data_url(self):
+        """
+        Requesting to preview a data URL is not supported.
+        """
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+        data = base64.b64encode(SMALL_PNG).decode()
+
+        query_params = urlencode(
+            {
+                "url": f'<html><head><img src="data:image/png;base64,{data}" /></head></html>'
+            }
+        )
+
+        channel = self.make_request(
+            "GET",
+            f"preview_url?{query_params}",
+            shorthand=False,
+        )
+        self.pump()
+
+        self.assertEqual(channel.code, 500)
+
+    def test_inline_data_url(self):
+        """
+        An inline image (as a data URL) should be parsed properly.
+        """
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+        data = base64.b64encode(SMALL_PNG)
+
+        end_content = (
+            b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
+        ) % (data,)
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=http://matrix.org",
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            (
+                b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+            )
+            % (len(end_content),)
+            + end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self._assert_small_png(channel.json_body)
+
     def test_oembed_photo(self):
         """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
         self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -626,10 +697,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         body = channel.json_body
         self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
-        self.assertTrue(body["og:image"].startswith("mxc://"))
-        self.assertEqual(body["og:image:height"], 1)
-        self.assertEqual(body["og:image:width"], 1)
-        self.assertEqual(body["og:image:type"], "image/png")
+        self._assert_small_png(body)
 
     def test_oembed_rich(self):
         """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
@@ -820,10 +888,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345"
         )
-        self.assertTrue(body["og:image"].startswith("mxc://"))
-        self.assertEqual(body["og:image:height"], 1)
-        self.assertEqual(body["og:image:width"], 1)
-        self.assertEqual(body["og:image:type"], "image/png")
+        self._assert_small_png(body)
 
     def _download_image(self):
         """Downloads an image into the URL cache.
diff --git a/tests/server.py b/tests/server.py
index a0cd14ea45..82990c2eb9 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -313,7 +313,7 @@ def make_request(
     req = request(channel, site)
     req.content = BytesIO(content)
     # Twisted expects to be at the end of the content when parsing the request.
-    req.content.seek(SEEK_END)
+    req.content.seek(0, SEEK_END)
 
     if access_token:
         req.requestHeaders.addRawHeader(