summary refs log tree commit diff
path: root/tests/rest/media/v1/test_url_preview.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/media/v1/test_url_preview.py')
-rw-r--r--tests/rest/media/v1/test_url_preview.py43
1 files changed, 41 insertions, 2 deletions
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 5148c39874..3b24d0ace6 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -17,7 +17,7 @@ import json
 import os
 import re
 from typing import Any, Dict, Optional, Sequence, Tuple, Type
-from urllib.parse import urlencode
+from urllib.parse import quote, urlencode
 
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
@@ -69,7 +69,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             "2001:800::/21",
         )
         config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
-        config["url_preview_url_blacklist"] = []
         config["url_preview_accept_language"] = [
             "en-UK",
             "en-US;q=0.9",
@@ -1123,3 +1122,43 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 os.path.exists(path),
                 f"{os.path.relpath(path, self.media_store_path)} was not deleted",
             )
+
+    @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]})
+    def test_blacklist_port(self) -> None:
+        """Tests that blacklisting URLs with a port makes previewing such URLs
+        fail with a 403 error and doesn't impact other previews.
+        """
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+        bad_url = quote("http://matrix.org:8888/foo")
+        good_url = quote("http://matrix.org/foo")
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=" + bad_url,
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+        self.assertEqual(channel.code, 403, channel.result)
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=" + good_url,
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+            % (len(self.end_content),)
+            + self.end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)