diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index b4e3f052cc..bcf25f298e 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -789,14 +789,15 @@ class RoomWorkerStore(SQLBaseStore):
self,
server_name: str,
media_id: str,
- quarantined_by: str,
+ quarantined_by: Optional[str],
) -> int:
- """quarantines a single local or remote media id
+ """quarantines or unquarantines a single local or remote media id
Args:
server_name: The name of the server that holds this media
media_id: The ID of the media to be quarantined
quarantined_by: The user ID that initiated the quarantine request
+ If it is `None` media will be removed from quarantine
"""
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server_name
@@ -863,9 +864,9 @@ class RoomWorkerStore(SQLBaseStore):
txn,
local_mxcs: List[str],
remote_mxcs: List[Tuple[str, str]],
- quarantined_by: str,
+ quarantined_by: Optional[str],
) -> int:
- """Quarantine local and remote media items
+ """Quarantine and unquarantine local and remote media items
Args:
txn (cursor)
@@ -873,18 +874,27 @@ class RoomWorkerStore(SQLBaseStore):
remote_mxcs: A list of (remote server, media id) tuples representing
remote mxc URLs
quarantined_by: The ID of the user who initiated the quarantine request
+ If it is `None` media will be removed from quarantine
Returns:
The total number of media items quarantined
"""
+
# Update all the tables to set the quarantined_by flag
- txn.executemany(
- """
+ sql = """
UPDATE local_media_repository
SET quarantined_by = ?
- WHERE media_id = ? AND safe_from_quarantine = ?
- """,
- ((quarantined_by, media_id, False) for media_id in local_mxcs),
- )
+ WHERE media_id = ?
+ """
+
+ # set quarantine
+ if quarantined_by is not None:
+ sql += "AND safe_from_quarantine = ?"
+ rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ # remove from quarantine
+ else:
+ rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+
+ txn.executemany(sql, rows)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
@@ -1523,7 +1533,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
room_id: str,
event_id: str,
user_id: str,
- reason: str,
+ reason: Optional[str],
content: JsonDict,
received_ts: int,
) -> None:
|