diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 46f643c6b9..13e366536a 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -626,36 +626,10 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- total_media_quarantined = 0
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany(
- """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """,
- ((quarantined_by, media_id) for media_id in local_mxcs),
- )
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_mxcs
- ),
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
)
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
-
- return total_media_quarantined
-
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -805,17 +779,17 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
The total number of media items quarantined
"""
- total_media_quarantined = 0
-
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
- WHERE media_id = ?
+ WHERE media_id = ? AND safe_from_quarantine = ?
""",
- ((quarantined_by, media_id) for media_id in local_mxcs),
+ ((quarantined_by, media_id, False) for media_id in local_mxcs),
)
+ # Note that a rowcount of -1 can be used to indicate no rows were affected.
+ total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
txn.executemany(
"""
@@ -825,9 +799,7 @@ class RoomWorkerStore(SQLBaseStore):
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)
-
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
|