summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10759.feature1
-rw-r--r--synapse/config/oembed.py24
-rw-r--r--synapse/rest/media/v1/oembed.py26
-rw-r--r--tests/rest/media/v1/test_url_preview.py55
4 files changed, 98 insertions, 8 deletions
diff --git a/changelog.d/10759.feature b/changelog.d/10759.feature
new file mode 100644
index 0000000000..7d18f5c133
--- /dev/null
+++ b/changelog.d/10759.feature
@@ -0,0 +1 @@
+Allow configuration of the oEmbed URLs used for URL previews.
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
index 09267b5eef..ea6ace4729 100644
--- a/synapse/config/oembed.py
+++ b/synapse/config/oembed.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import json
 import re
-from typing import Any, Dict, Iterable, List, Pattern
+from typing import Any, Dict, Iterable, List, Optional, Pattern
 from urllib import parse as urlparse
 
 import attr
@@ -31,6 +31,8 @@ class OEmbedEndpointConfig:
     api_endpoint: str
     # The patterns to match.
     url_patterns: List[Pattern]
+    # The supported formats.
+    formats: Optional[List[str]]
 
 
 class OembedConfig(Config):
@@ -93,11 +95,22 @@ class OembedConfig(Config):
             # might have multiple patterns to match.
             for endpoint in provider["endpoints"]:
                 api_endpoint = endpoint["url"]
+
+                # The API endpoint must be an HTTP(S) URL.
+                results = urlparse.urlparse(api_endpoint)
+                if results.scheme not in {"http", "https"}:
+                    raise ConfigError(
+                        f"Unsupported oEmbed scheme ({results.scheme}) for endpoint {api_endpoint}",
+                        config_path,
+                    )
+
                 patterns = [
                     self._glob_to_pattern(glob, config_path)
                     for glob in endpoint["schemes"]
                 ]
-                yield OEmbedEndpointConfig(api_endpoint, patterns)
+                yield OEmbedEndpointConfig(
+                    api_endpoint, patterns, endpoint.get("formats")
+                )
 
     def _glob_to_pattern(self, glob: str, config_path: Iterable[str]) -> Pattern:
         """
@@ -114,9 +127,12 @@ class OembedConfig(Config):
         """
         results = urlparse.urlparse(glob)
 
-        # Ensure the scheme does not have wildcards (and is a sane scheme).
+        # The scheme must be HTTP(S) (and cannot contain wildcards).
         if results.scheme not in {"http", "https"}:
-            raise ConfigError(f"Insecure oEmbed scheme: {results.scheme}", config_path)
+            raise ConfigError(
+                f"Unsupported oEmbed scheme ({results.scheme}) for pattern: {glob}",
+                config_path,
+            )
 
         pattern = urlparse.urlunparse(
             [
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index afe41823e4..2e6706dbfa 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -49,8 +49,24 @@ class OEmbedProvider:
     def __init__(self, hs: "HomeServer", client: SimpleHttpClient):
         self._oembed_patterns = {}
         for oembed_endpoint in hs.config.oembed.oembed_patterns:
+            api_endpoint = oembed_endpoint.api_endpoint
+
+            # Only JSON is supported at the moment. This could be declared in
+            # the formats field. Otherwise, if the endpoint ends in .xml assume
+            # it doesn't support JSON.
+            if (
+                oembed_endpoint.formats is not None
+                and "json" not in oembed_endpoint.formats
+            ) or api_endpoint.endswith(".xml"):
+                logger.info(
+                    "Ignoring oEmbed endpoint due to not supporting JSON: %s",
+                    api_endpoint,
+                )
+                continue
+
+            # Iterate through each URL pattern and point it to the endpoint.
             for pattern in oembed_endpoint.url_patterns:
-                self._oembed_patterns[pattern] = oembed_endpoint.api_endpoint
+                self._oembed_patterns[pattern] = api_endpoint
         self._client = client
 
     def get_oembed_url(self, url: str) -> Optional[str]:
@@ -86,11 +102,15 @@ class OEmbedProvider:
         """
         try:
             logger.debug("Trying to get oEmbed content for url '%s'", url)
+
+            # Note that only the JSON format is supported, some endpoints want
+            # this in the URL, others want it as an argument.
+            endpoint = endpoint.replace("{format}", "json")
+
             result = await self._client.get_json(
                 endpoint,
                 # TODO Specify max height / width.
-                # Note that only the JSON format is supported.
-                args={"url": url},
+                args={"url": url, "format": "json"},
             )
 
             # Ensure there's a version of 1.0.
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 7fa9027227..9f6fbfe6de 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -92,7 +92,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 url_patterns=[
                     re.compile(r"http://twitter\.com/.+/status/.+"),
                 ],
-            )
+                formats=None,
+            ),
+            OEmbedEndpointConfig(
+                api_endpoint="http://www.hulu.com/api/oembed.{format}",
+                url_patterns=[
+                    re.compile(r"http://www\.hulu\.com/watch/.+"),
+                ],
+                formats=["json"],
+            ),
         ]
 
         return hs
@@ -656,3 +664,48 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             channel.json_body,
             {"og:title": None, "og:description": "Content Preview"},
         )
+
+    def test_oembed_format(self):
+        """Test an oEmbed endpoint which requires the format in the URL."""
+        self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
+
+        result = {
+            "version": "1.0",
+            "type": "rich",
+            "html": "<div>Content Preview</div>",
+        }
+        end_content = json.dumps(result).encode("utf-8")
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=http://www.hulu.com/watch/12345",
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            (
+                b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+            )
+            % (len(end_content),)
+            + end_content
+        )
+
+        self.pump()
+
+        # The {format} should have been turned into json.
+        self.assertIn(b"/api/oembed.json", server.data)
+        # A URL parameter of format=json should be provided.
+        self.assertIn(b"format=json", server.data)
+
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body,
+            {"og:title": None, "og:description": "Content Preview"},
+        )