summary refs log tree commit diff
path: root/tests/rest/media/test_url_preview.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/media/test_url_preview.py')
-rw-r--r--tests/rest/media/test_url_preview.py240
1 files changed, 210 insertions, 30 deletions
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py

index e44beae8c1..170fb0534a 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py
@@ -418,9 +418,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_specific(self) -> None: + def test_blocked_ip_specific(self) -> None: """ - Blacklisted IP addresses, found via DNS, are not spidered. + Blocked IP addresses, found via DNS, are not spidered. """ self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] @@ -439,9 +439,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_range(self) -> None: + def test_blocked_ip_range(self) -> None: """ - Blacklisted IP ranges, IPs found over DNS, are not spidered. + Blocked IP ranges, IPs found over DNS, are not spidered. """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] @@ -458,9 +458,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ip_specific_direct(self) -> None: + def test_blocked_ip_specific_direct(self) -> None: """ - Blacklisted IP addresses, accessed directly, are not spidered. + Blocked IP addresses, accessed directly, are not spidered. """ channel = self.make_request( "GET", "preview_url?url=http://192.168.1.1", shorthand=False @@ -470,16 +470,13 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(len(self.reactor.tcpClients), 0) self.assertEqual( channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "IP address blocked by IP blacklist entry", - }, + {"errcode": "M_UNKNOWN", "error": "IP address blocked"}, ) self.assertEqual(channel.code, 403) - def test_blacklisted_ip_range_direct(self) -> None: + def test_blocked_ip_range_direct(self) -> None: """ - Blacklisted IP ranges, accessed directly, are not spidered. + Blocked IP ranges, accessed directly, are not spidered. """ channel = self.make_request( "GET", "preview_url?url=http://1.1.1.2", shorthand=False @@ -488,15 +485,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 403) self.assertEqual( channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "IP address blocked by IP blacklist entry", - }, + {"errcode": "M_UNKNOWN", "error": "IP address blocked"}, ) - def test_blacklisted_ip_range_whitelisted_ip(self) -> None: + def test_blocked_ip_range_whitelisted_ip(self) -> None: """ - Blacklisted but then subsequently whitelisted IP addresses can be + Blocked but then subsequently whitelisted IP addresses can be spidered. """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] @@ -527,10 +521,10 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} ) - def test_blacklisted_ip_with_external_ip(self) -> None: + def test_blocked_ip_with_external_ip(self) -> None: """ - If a hostname resolves a blacklisted IP, even if there's a - non-blacklisted one, it will be rejected. + If a hostname resolves a blocked IP, even if there's a non-blocked one, + it will be rejected. """ # Hardcode the URL resolving to the IP we want. self.lookups["example.com"] = [ @@ -550,9 +544,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ipv6_specific(self) -> None: + def test_blocked_ipv6_specific(self) -> None: """ - Blacklisted IP addresses, found via DNS, are not spidered. + Blocked IP addresses, found via DNS, are not spidered. """ self.lookups["example.com"] = [ (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") @@ -573,9 +567,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) - def test_blacklisted_ipv6_range(self) -> None: + def test_blocked_ipv6_range(self) -> None: """ - Blacklisted IP ranges, IPs found over DNS, are not spidered. + Blocked IP ranges, IPs found over DNS, are not spidered. """ self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] @@ -653,6 +647,57 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) + def test_image(self) -> None: + """An image should be precached if mentioned in the HTML.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")] + + result = ( + b"""<html><body><img src="http://cdn.matrix.org/foo.png"></body></html>""" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + # Respond with the HTML. + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + self.pump() + + # Respond with the photo. + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: image/png\r\n\r\n" + ) + % (len(SMALL_PNG),) + + SMALL_PNG + ) + self.pump() + + # The image should be in the result. + self.assertEqual(channel.code, 200) + self._assert_small_png(channel.json_body) + def test_nonexistent_image(self) -> None: """If the preview image doesn't exist, ensure some data is returned.""" self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] @@ -683,9 +728,53 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self.pump() + + # There should not be a second connection. + self.assertEqual(len(self.reactor.tcpClients), 1) + + # The image should not be in the result. self.assertEqual(channel.code, 200) + self.assertNotIn("og:image", channel.json_body) + + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "cdn.matrix.org"}]} + ) + def test_image_blocked(self) -> None: + """If the preview image doesn't exist, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")] + + result = ( + b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>""" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + self.pump() + + # There should not be a second connection. + self.assertEqual(len(self.reactor.tcpClients), 1) # The image should not be in the result. + self.assertEqual(channel.code, 200) self.assertNotIn("og:image", channel.json_body) def test_oembed_failure(self) -> None: @@ -880,6 +969,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self.pump() + + # Double check that the proper host is being connected to. (Note that + # twitter.com can't be resolved so this is already implicitly checked.) + self.assertIn(b"\r\nHost: publish.twitter.com\r\n", server.data) + self.assertEqual(channel.code, 200) body = channel.json_body self.assertEqual( @@ -940,6 +1034,22 @@ class URLPreviewTests(unittest.HomeserverTestCase): }, ) + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]} + ) + def test_oembed_blocked(self) -> None: + """The oEmbed URL should not be downloaded if the oEmbed URL is blocked.""" + self.lookups["twitter.com"] = [(IPv4Address, "10.1.2.3")] + + channel = self.make_request( + "GET", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 403, channel.result) + def test_oembed_autodiscovery(self) -> None: """ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. @@ -980,7 +1090,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): % (len(result),) + result ) - self.pump() # The oEmbed response. @@ -1004,7 +1113,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): % (len(oembed_content),) + oembed_content ) - self.pump() # Ensure the URL is what was requested. @@ -1023,7 +1131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): % (len(SMALL_PNG),) + SMALL_PNG ) - self.pump() # Ensure the URL is what was requested. @@ -1036,6 +1143,59 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) self._assert_small_png(body) + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]} + ) + def test_oembed_autodiscovery_blocked(self) -> None: + """ + If the discovered oEmbed URL is blocked, it should be discarded. + """ + # This is a little cheesy in that we use the www subdomain (which isn't the + # list of oEmbed patterns) to get "raw" HTML response. + self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.4")] + + result = b""" + <title>Test</title> + <link rel="alternate" type="application/json+oembed" + href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json" + title="matrixdotorg" /> + """ + + channel = self.make_request( + "GET", + "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + + self.pump() + + # Ensure there's no additional connections. + self.assertEqual(len(self.reactor.tcpClients), 1) + + # Ensure the URL is what was requested. + self.assertIn(b"\r\nHost: www.twitter.com\r\n", server.data) + + self.assertEqual(channel.code, 200) + body = channel.json_body + self.assertEqual(body["og:title"], "Test") + self.assertNotIn("og:image", body) + def _download_image(self) -> Tuple[str, str]: """Downloads an image into the URL cache. Returns: @@ -1192,8 +1352,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]}) - def test_blacklist_port(self) -> None: - """Tests that blacklisting URLs with a port makes previewing such URLs + def test_blocked_port(self) -> None: + """Tests that blocking URLs with a port makes previewing such URLs fail with a 403 error and doesn't impact other previews. """ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] @@ -1230,3 +1390,23 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) + + @unittest.override_config( + {"url_preview_url_blacklist": [{"netloc": "example.com"}]} + ) + def test_blocked_url(self) -> None: + """Tests that blocking URLs with a host makes previewing such URLs + fail with a 403 error. + """ + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] + + bad_url = quote("http://example.com/foo") + + channel = self.make_request( + "GET", + "preview_url?url=" + bad_url, + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 403, channel.result)