diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d673adba16..56217fccdf 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
-from synapse.storage._base import db_to_json, make_in_list_sql_clause
+from synapse.storage._base import (
+ db_to_json,
+ make_in_list_sql_clause,
+)
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_tuple_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
@@ -1127,6 +1131,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return local_media_ids
+ def _quarantine_local_media_txn(
+ self,
+ txn: LoggingTransaction,
+ hashes: Set[str],
+ media_ids: Set[str],
+ quarantined_by: Optional[str],
+ ) -> int:
+ """Quarantine or unquarantine local media items.
+
+ Runs up to two UPDATEs against `local_media_repository`: one matching
+ `media_ids` on the `media_id` column, then one matching `hashes` on the
+ `sha256` column. An empty set skips its statement entirely. When
+ quarantining (`quarantined_by` is not `None`), only rows with
+ `safe_from_quarantine = FALSE` are updated.
+
+ Args:
+ txn: The database transaction the updates are executed on
+ hashes: A set of sha256 hashes for any media that should be quarantined
+ media_ids: A set of media IDs for any media that should be quarantined
+ quarantined_by: The ID of the user who initiated the quarantine request
+ If it is `None` media will be removed from quarantine
+ Returns:
+ The total number of media items updated (quarantined, or removed from
+ quarantine), summed over both statements
+ """
+ total_media_quarantined = 0
+
+ # Effectively a legacy path, update any media that was explicitly named.
+ if media_ids:
+ sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
+ txn.database_engine, "media_id", media_ids
+ )
+ sql = f"""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE {sql_many_clause_sql}"""
+
+ # Only skip protected media when quarantining; removing from quarantine
+ # applies to every matched row.
+ if quarantined_by is not None:
+ sql += " AND safe_from_quarantine = FALSE"
+
+ txn.execute(sql, [quarantined_by] + sql_many_clause_args)
+ # Note that a rowcount of -1 can be used to indicate no rows were affected.
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+ # Update any media that was identified via hash.
+ if hashes:
+ sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
+ txn.database_engine, "sha256", hashes
+ )
+ sql = f"""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE {sql_many_clause_sql}"""
+
+ # Same rule as above: the safe_from_quarantine filter is only applied
+ # when quarantining.
+ if quarantined_by is not None:
+ sql += " AND safe_from_quarantine = FALSE"
+
+ txn.execute(sql, [quarantined_by] + sql_many_clause_args)
+ # Clamp negative rowcounts to zero so they never reduce the running total.
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+ return total_media_quarantined
+
+ def _quarantine_remote_media_txn(
+ self,
+ txn: LoggingTransaction,
+ hashes: Set[str],
+ media: Set[Tuple[str, str]],
+ quarantined_by: Optional[str],
+ ) -> int:
+ """Quarantine or unquarantine remote media items.
+
+ Runs up to two UPDATEs against `remote_media_cache`: one matching `media`
+ on the `(media_origin, media_id)` columns, then one matching `hashes` on
+ the `sha256` column. An empty set skips its statement entirely.
+
+ Args:
+ txn: The database transaction the updates are executed on
+ hashes: A set of sha256 hashes for any media that should be quarantined
+ media: A set of tuples (media_origin, media_id) for any media that should be quarantined
+ quarantined_by: The ID of the user who initiated the quarantine request
+ If it is `None` media will be removed from quarantine
+ Returns:
+ The total number of media items updated (quarantined, or removed from
+ quarantine), summed over both statements
+ """
+ total_media_quarantined = 0
+
+ # Update any media that was explicitly named by (origin, media ID).
+ if media:
+ sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause(
+ txn.database_engine,
+ ("media_origin", "media_id"),
+ media,
+ )
+ sql = f"""
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE {sql_in_list_clause}"""
+
+ txn.execute(sql, [quarantined_by] + sql_args)
+ # Note that a rowcount of -1 can be used to indicate no rows were affected.
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+ # Update any media that was identified via hash. The running total must
+ # NOT be reset here, otherwise rows updated above are lost from the count.
+ if hashes:
+ sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
+ txn.database_engine, "sha256", hashes
+ )
+ sql = f"""
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE {sql_many_clause_sql}"""
+ txn.execute(sql, [quarantined_by] + sql_many_clause_args)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+ return total_media_quarantined
+
def _quarantine_media_txn(
self,
txn: LoggingTransaction,
@@ -1146,40 +1253,49 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Returns:
The total number of media items quarantined
"""
-
- # Update all the tables to set the quarantined_by flag
- sql = """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """
-
- # set quarantine
- if quarantined_by is not None:
- sql += "AND safe_from_quarantine = FALSE"
- txn.executemany(
- sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+ hashes = set()
+ media_ids = set()
+ remote_media = set()
+
+ # First, determine the hashes of the media we want to quarantine.
+ # We also want the media_ids for any media that lacks a hash.
+ if local_mxcs:
+ hash_sql_many_clause_sql, hash_sql_many_clause_args = (
+ make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs)
)
- # remove from quarantine
- else:
- txn.executemany(
- sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+ hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}"
+ if quarantined_by is not None:
+ hash_sql += " AND safe_from_quarantine = FALSE"
+
+ txn.execute(hash_sql, hash_sql_many_clause_args)
+ for sha256, media_id in txn:
+ if sha256:
+ hashes.add(sha256)
+ else:
+ media_ids.add(media_id)
+
+ # Do the same for remote media
+ if remote_mxcs:
+ hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause(
+ txn.database_engine,
+ ("media_origin", "media_id"),
+ remote_mxcs,
)
- # Note that a rowcount of -1 can be used to indicate no rows were affected.
- total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
+ hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}"
+ txn.execute(hash_sql, hash_sql_args)
+ for sha256, media_origin, media_id in txn:
+ if sha256:
+ hashes.add(sha256)
+ else:
+ remote_media.add((media_origin, media_id))
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- [(quarantined_by, origin, media_id) for origin, media_id in remote_mxcs],
+ count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by)
+ count += self._quarantine_remote_media_txn(
+ txn, hashes, remote_media, quarantined_by
)
- total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
- return total_media_quarantined
+ return count
async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked.
|