diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/rest/media/v1/test_url_preview.py | 43 |
1 files changed, 41 insertions, 2 deletions
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 5148c39874..3b24d0ace6 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -17,7 +17,7 @@ import json import os import re from typing import Any, Dict, Optional, Sequence, Tuple, Type -from urllib.parse import urlencode +from urllib.parse import quote, urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -69,7 +69,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): "2001:800::/21", ) config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) - config["url_preview_url_blacklist"] = [] config["url_preview_accept_language"] = [ "en-UK", "en-US;q=0.9", @@ -1123,3 +1122,43 @@ class URLPreviewTests(unittest.HomeserverTestCase): os.path.exists(path), f"{os.path.relpath(path, self.media_store_path)} was not deleted", ) + + @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]}) + def test_blacklist_port(self) -> None: + """Tests that blacklisting URLs with a port makes previewing such URLs + fail with a 403 error and doesn't impact other previews. + """ + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + bad_url = quote("http://matrix.org:8888/foo") + good_url = quote("http://matrix.org/foo") + + channel = self.make_request( + "GET", + "preview_url?url=" + bad_url, + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 403, channel.result) + + channel = self.make_request( + "GET", + "preview_url?url=" + good_url, + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) |