diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/media/test_url_previewer.py | 113 | ||||
-rw-r--r-- | tests/rest/media/test_url_preview.py | 194 |
2 files changed, 303 insertions, 4 deletions
diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py new file mode 100644 index 0000000000..3c4c7d6765 --- /dev/null +++ b/tests/media/test_url_previewer.py @@ -0,0 +1,113 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + +try: + import lxml +except ImportError: + lxml = None + + +class URLPreviewTests(unittest.HomeserverTestCase): + if not lxml: + skip = "url preview feature requires lxml" + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["url_preview_enabled"] = True + config["max_spider_size"] = 9999999 + config["url_preview_ip_range_blacklist"] = ( + "192.168.1.1", + "1.0.0.0/8", + "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:800::/21", + ) + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = 
[provider_config] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + media_repo_resource = hs.get_media_repository_resource() + preview_url = media_repo_resource.children[b"preview_url"] + self.url_previewer = preview_url._url_previewer + + def test_all_urls_allowed(self) -> None: + self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org")) + self.assertFalse(self.url_previewer._is_url_blocked("https://matrix.org")) + self.assertFalse(self.url_previewer._is_url_blocked("http://localhost:8000")) + self.assertFalse( + self.url_previewer._is_url_blocked("http://user:pass@matrix.org") + ) + + @override_config( + { + "url_preview_url_blacklist": [ + {"username": "user"}, + {"scheme": "http", "netloc": "matrix.org"}, + ] + } + ) + def test_blocked_url(self) -> None: + # Blocked via scheme and netloc. + self.assertTrue(self.url_previewer._is_url_blocked("http://matrix.org")) + # Not blocked because all components must match. + self.assertFalse(self.url_previewer._is_url_blocked("https://matrix.org")) + + # Blocked due to the user. + self.assertTrue( + self.url_previewer._is_url_blocked("http://user:pass@example.com") + ) + self.assertTrue(self.url_previewer._is_url_blocked("http://user@example.com")) + + @override_config({"url_preview_url_blacklist": [{"netloc": "*.example.com"}]}) + def test_glob_blocked_url(self) -> None: + # All subdomains are blocked. + self.assertTrue(self.url_previewer._is_url_blocked("http://foo.example.com")) + self.assertTrue(self.url_previewer._is_url_blocked("http://.example.com")) + + # The bare domain (no subdomain) is not blocked. + self.assertFalse(self.url_previewer._is_url_blocked("https://example.com")) + + @override_config({"url_preview_url_blacklist": [{"netloc": "^.+\\.example\\.com"}]}) + def test_regex_blocked_urL(self) -> None: + # All subdomains are blocked. 
+ self.assertTrue(self.url_previewer._is_url_blocked("http://foo.example.com")) + # Requires a non-empty subdomain. + self.assertFalse(self.url_previewer._is_url_blocked("http://.example.com")) + + # The bare domain (no subdomain) is not blocked. + self.assertFalse(self.url_previewer._is_url_blocked("https://example.com")) diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py index e44beae8c1..7517155cf3 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py @@ -653,6 +653,57 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) + def test_image(self) -> None: + """An image should be precached if mentioned in the HTML.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")] + + result = ( + b"""<html><body><img src="http://cdn.matrix.org/foo.png"></body></html>""" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + # Respond with the HTML. + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + self.pump() + + # Respond with the photo. + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: image/png\r\n\r\n" + ) + % (len(SMALL_PNG),) + + SMALL_PNG + ) + self.pump() + + # The image should be in the result. 
+ self.assertEqual(channel.code, 200) + self._assert_small_png(channel.json_body) + def test_nonexistent_image(self) -> None: """If the preview image doesn't exist, ensure some data is returned.""" self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] @@ -683,9 +734,53 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self.pump() + + # There should not be a second connection. + self.assertEqual(len(self.reactor.tcpClients), 1) + + # The image should not be in the result. self.assertEqual(channel.code, 200) + self.assertNotIn("og:image", channel.json_body) + + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "cdn.matrix.org"}]} + ) + def test_image_blocked(self) -> None: + """If the preview image's URL is blocked, it should not be downloaded or included in the result.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")] + + result = ( + b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>""" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + self.pump() + + # There should not be a second connection. + self.assertEqual(len(self.reactor.tcpClients), 1) # The image should not be in the result. + self.assertEqual(channel.code, 200) self.assertNotIn("og:image", channel.json_body) def test_oembed_failure(self) -> None: @@ -880,6 +975,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self.pump() + + # Double check that the proper host is being connected to. 
(Note that + # twitter.com can't be resolved so this is already implicitly checked.) + self.assertIn(b"\r\nHost: publish.twitter.com\r\n", server.data) + self.assertEqual(channel.code, 200) body = channel.json_body self.assertEqual( @@ -940,6 +1040,22 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]} + ) + def test_oembed_blocked(self) -> None: + """The oEmbed URL should not be downloaded if the oEmbed URL is blocked.""" + self.lookups["twitter.com"] = [(IPv4Address, "10.1.2.3")] + + channel = self.make_request( + "GET", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 403, channel.result) + def test_oembed_autodiscovery(self) -> None: """ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. @@ -980,7 +1096,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): % (len(result),) + result ) - self.pump() # The oEmbed response. @@ -1004,7 +1119,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): % (len(oembed_content),) + oembed_content ) - self.pump() # Ensure the URL is what was requested. @@ -1023,7 +1137,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): % (len(SMALL_PNG),) + SMALL_PNG ) - self.pump() # Ensure the URL is what was requested. @@ -1036,6 +1149,59 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self._assert_small_png(body) + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]} + ) + def test_oembed_autodiscovery_blocked(self) -> None: + """ + If the discovered oEmbed URL is blocked, it should be discarded. + """ + # This is a little cheesy in that we use the www subdomain (which isn't in the + # list of oEmbed patterns) to get a "raw" HTML response. 
+ self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.4")] + + result = b""" + <title>Test</title> + <link rel="alternate" type="application/json+oembed" + href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json" + title="matrixdotorg" /> + """ + + channel = self.make_request( + "GET", + "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + + self.pump() + + # Ensure there's no additional connections. + self.assertEqual(len(self.reactor.tcpClients), 1) + + # Ensure the URL is what was requested. + self.assertIn(b"\r\nHost: www.twitter.com\r\n", server.data) + + self.assertEqual(channel.code, 200) + body = channel.json_body + self.assertEqual(body["og:title"], "Test") + self.assertNotIn("og:image", body) + def _download_image(self) -> Tuple[str, str]: """Downloads an image into the URL cache. Returns: @@ -1192,7 +1358,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]}) - def test_blacklist_port(self) -> None: + def test_blocked_port(self) -> None: """Tests that blacklisting URLs with a port makes previewing such URLs fail with a 403 error and doesn't impact other previews. 
""" @@ -1230,3 +1396,23 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) + + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "example.com"}]} + ) + def test_blocked_url(self) -> None: + """Tests that blacklisting URLs with a host makes previewing such URLs + fail with a 403 error. + """ + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] + + bad_url = quote("http://example.com/foo") + + channel = self.make_request( + "GET", + "preview_url?url=" + bad_url, + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 403, channel.result) |