summary refs log tree commit diff
path: root/synapse/rest/media/download_resource.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/media/download_resource.py')
-rw-r--r--synapse/rest/media/download_resource.py40
1 files changed, 26 insertions, 14 deletions
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 3c618ef60a..65b9ff52fa 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -13,16 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING
+import re
+from typing import TYPE_CHECKING, Optional
 
-from synapse.http.server import (
-    DirectServeJsonResource,
-    set_corp_headers,
-    set_cors_headers,
-)
-from synapse.http.servlet import parse_boolean
+from synapse.http.server import set_corp_headers, set_cors_headers
+from synapse.http.servlet import RestServlet, parse_boolean
 from synapse.http.site import SynapseRequest
-from synapse.media._base import parse_media_id, respond_404
+from synapse.media._base import respond_404
+from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.media.media_repository import MediaRepository
@@ -31,15 +29,28 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class DownloadResource(DirectServeJsonResource):
-    isLeaf = True
+class DownloadResource(RestServlet):
+    PATTERNS = [
+        re.compile(
+            "/_matrix/media/(r0|v3|v1)/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$"
+        )
+    ]
 
     def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
         self._is_mine_server_name = hs.is_mine_server_name
 
-    async def _async_render_GET(self, request: SynapseRequest) -> None:
+    async def on_GET(
+        self,
+        request: SynapseRequest,
+        server_name: str,
+        media_id: str,
+        file_name: Optional[str] = None,
+    ) -> None:
+        # Validate the server name, raising if invalid
+        parse_and_validate_server_name(server_name)
+
         set_cors_headers(request)
         set_corp_headers(request)
         request.setHeader(
@@ -58,9 +69,8 @@ class DownloadResource(DirectServeJsonResource):
             b"Referrer-Policy",
             b"no-referrer",
         )
-        server_name, media_id, name = parse_media_id(request)
         if self._is_mine_server_name(server_name):
-            await self.media_repo.get_local_media(request, media_id, name)
+            await self.media_repo.get_local_media(request, media_id, file_name)
         else:
             allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
@@ -72,4 +82,6 @@ class DownloadResource(DirectServeJsonResource):
                 respond_404(request)
                 return
 
-            await self.media_repo.get_remote_media(request, server_name, media_id, name)
+            await self.media_repo.get_remote_media(
+                request, server_name, media_id, file_name
+            )