diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
assert_user_is_admin,
)
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, room_id: str):
+ async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, user_id: str):
+ async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, server_name: str, media_id: str):
+ async def on_POST(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
+class ProtectMediaByID(RestServlet):
+ """Protect local media from being quarantined.
+ """
+
+ PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Protecting local media by ID: %s", media_id)
+
+ # Quarantine this media id
+ await self.store.mark_local_media_as_safe(media_id)
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_DELETE(self, request, server_name: str, media_id: str):
+ async def on_DELETE(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request, server_name: str):
+ async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
+ ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)
|