summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/room.py')
-rw-r--r--synapse/storage/databases/main/room.py174
1 files changed, 145 insertions, 29 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py

index d673adba16..56217fccdf 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream -from synapse.storage._base import db_to_json, make_in_list_sql_clause +from synapse.storage._base import ( + db_to_json, + make_in_list_sql_clause, +) from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor @@ -1127,6 +1131,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return local_media_ids + def _quarantine_local_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media_ids: Set[str], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine local media items. + + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of media IDs for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + # Effectively a legacy path, update any media that was explicitly named. + if media_ids: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "media_id", media_ids + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + # Note that a rowcount of -1 can be used to indicate no rows were affected. + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + # Update any media that was identified via hash. + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + + def _quarantine_remote_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media: Set[Tuple[str, str]], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine remote items + + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of tuples (media_origin, media_id) for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + if media: + sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + media, + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE {sql_in_list_clause}""" + + txn.execute(sql, [quarantined_by] + sql_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + total_media_quarantined = 0 + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + def _quarantine_media_txn( self, txn: LoggingTransaction, @@ -1146,40 +1253,49 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): Returns: The total number of media items quarantined """ - - # Update all the tables to set the quarantined_by flag - sql = """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """ - - # set quarantine - if quarantined_by is not None: - sql += "AND safe_from_quarantine = FALSE" - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hashes = set() + media_ids = set() + remote_media = set() + + # First, determine the hashes of the media we want to delete. + # We also want the media_ids for any media that lacks a hash. + if local_mxcs: + hash_sql_many_clause_sql, hash_sql_many_clause_args = ( + make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs) ) - # remove from quarantine - else: - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}" + if quarantined_by is not None: + hash_sql += " AND safe_from_quarantine = FALSE" + + txn.execute(hash_sql, hash_sql_many_clause_args) + for sha256, media_id in txn: + if sha256: + hashes.add(sha256) + else: + media_ids.add(media_id) + + # Do the same for remote media + if remote_mxcs: + hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + remote_mxcs, ) - # Note that a rowcount of -1 can be used to indicate no rows were affected. - total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 + hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}" + txn.execute(hash_sql, hash_sql_args) + for sha256, media_origin, media_id in txn: + if sha256: + hashes.add(sha256) + else: + remote_media.add((media_origin, media_id)) - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - [(quarantined_by, origin, media_id) for origin, media_id in remote_mxcs], + count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by) + count += self._quarantine_remote_media_txn( + txn, hashes, remote_media, quarantined_by ) - total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 - return total_media_quarantined + return count async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked.