diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index fa833e54cf..3a445d6eed 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -32,16 +32,24 @@ class QuarantineMediaInRoom(RestServlet):
this server.
"""
- PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ PATTERNS = (
+ historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ +
+ # This path kept around for legacy reasons
+ historical_admin_path_patterns("/quarantine_media/(?P<room_id>![^/]+)")
+ )
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, room_id):
+ async def on_POST(self, request, room_id: str):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
+ logging.info("Quarantining room: %s", room_id)
+
+ # Quarantine all media in this room
num_quarantined = await self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string()
)
@@ -49,6 +57,60 @@ class QuarantineMediaInRoom(RestServlet):
return 200, {"num_quarantined": num_quarantined}
+class QuarantineMediaByUser(RestServlet):
+ """Quarantines all local media by a given user so that no one can download it via
+ this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/user/(?P<user_id>[^/]+)/media/quarantine"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by user: %s", user_id)
+
+ # Quarantine all media this user has uploaded
+ num_quarantined = await self.store.quarantine_media_ids_by_user(
+ user_id, requester.user.to_string()
+ )
+
+ return 200, {"num_quarantined": num_quarantined}
+
+
+class QuarantineMediaByID(RestServlet):
+ """Quarantines local or remote media by a given ID so that no one can download
+ it via this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, server_name: str, media_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
+
+ # Quarantine this media id
+ await self.store.quarantine_media_by_id(
+ server_name, media_id, requester.user.to_string()
+ )
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
@@ -94,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server):
"""
PurgeMediaCacheRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ QuarantineMediaByID(hs).register(http_server)
+ QuarantineMediaByUser(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 8636d75030..49bab62be3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -18,7 +18,7 @@ import collections
import logging
import re
from abc import abstractmethod
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
from six import integer_types
@@ -399,6 +399,8 @@ class RoomWorkerStore(SQLBaseStore):
the associated media
"""
+ logger.info("Quarantining media in room: %s", room_id)
+
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
@@ -494,6 +496,118 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
+ def quarantine_media_by_id(
+ self, server_name: str, media_id: str, quarantined_by: str,
+ ):
+ """quarantines a single local or remote media id
+
+ Args:
+ server_name: The name of the server that holds this media
+ media_id: The ID of the media to be quarantined
+ quarantined_by: The user ID that initiated the quarantine request
+ """
+ logger.info("Quarantining media: %s/%s", server_name, media_id)
+ is_local = server_name == self.config.server_name
+
+ def _quarantine_media_by_id_txn(txn):
+ local_mxcs = [media_id] if is_local else []
+ remote_mxcs = [(server_name, media_id)] if not is_local else []
+
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
+ )
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_id_txn
+ )
+
+ def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ """quarantines all local media associated with a single user
+
+ Args:
+ user_id: The ID of the user to quarantine media of
+ quarantined_by: The ID of the user who made the quarantine request
+ """
+
+ def _quarantine_media_by_user_txn(txn):
+ local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
+ return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_user_txn
+ )
+
+ def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+ """Retrieves local media IDs by a given user
+
+ Args:
+ txn (cursor)
+ user_id: The ID of the user to retrieve media IDs of
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+ # Local media
+ sql = """
+ SELECT media_id
+ FROM local_media_repository
+ WHERE user_id = ?
+ """
+ if filter_quarantined:
+ sql += "AND quarantined_by IS NULL"
+ txn.execute(sql, (user_id,))
+
+ local_media_ids = [row[0] for row in txn]
+
+ # TODO: Figure out all remote media a user has referenced in a message
+
+ return local_media_ids
+
+ def _quarantine_media_txn(
+ self,
+ txn,
+ local_mxcs: List[str],
+ remote_mxcs: List[Tuple[str, str]],
+ quarantined_by: str,
+ ) -> int:
+ """Quarantine local and remote media items
+
+ Args:
+ txn (cursor)
+ local_mxcs: A list of local mxc URLs
+ remote_mxcs: A list of (remote server, media id) tuples representing
+ remote mxc URLs
+ quarantined_by: The ID of the user who initiated the quarantine request
+ Returns:
+ The total number of media items quarantined
+ """
+ total_media_quarantined = 0
+
+ # Update all the tables to set the quarantined_by flag
+ txn.executemany(
+ """
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """,
+ ((quarantined_by, media_id) for media_id in local_mxcs),
+ )
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
+ )
+
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
+
+ return total_media_quarantined
+
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|