summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- synapse/storage/databases/main/room.py 345
1 files changed, 265 insertions, 80 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py

index 80a4bf95f2..347dbbba6b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream -from synapse.storage._base import db_to_json, make_in_list_sql_clause +from synapse.storage._base import ( + db_to_json, + make_in_list_sql_clause, +) from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor @@ -73,6 +77,8 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitOverride: + # n.b. elsewhere in Synapse messages_per_second is represented as a float, but it is + # an integer in the database messages_per_second: int burst_count: int @@ -604,6 +610,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): search_term: Optional[str], public_rooms: Optional[bool], empty_rooms: Optional[bool], + emma_include_tombstone: bool = False, ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of rooms as json. @@ -623,6 +630,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): If true, empty rooms are queried. if false, empty rooms are excluded from the query. When it is none (the default), both empty rooms and none-empty rooms are queried. + emma_include_tombstone: If true, include tombstone events in the results. 
Returns: A list of room dicts and an integer representing the total number of rooms that exist given this query @@ -791,11 +799,43 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): room_count = cast(Tuple[int], txn.fetchone()) return rooms, room_count[0] - return await self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_rooms_paginate", _get_rooms_paginate_txn, ) + if emma_include_tombstone: + room_id_sql, room_id_args = make_in_list_sql_clause( + self.database_engine, "cse.room_id", [r["room_id"] for r in result[0]] + ) + + tombstone_sql = """ + SELECT cse.room_id, cse.event_id, ej.json + FROM current_state_events cse + JOIN event_json ej USING (event_id) + WHERE cse.type = 'm.room.tombstone' + AND {room_id_sql} + """.format( + room_id_sql=room_id_sql + ) + + def _get_tombstones_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: + txn.execute(tombstone_sql, room_id_args) + for room_id, event_id, json in txn: + for result_room in result[0]: + if result_room["room_id"] == room_id: + result_room["gay.rory.synapse_admin_extensions.tombstone"] = db_to_json(json) + break + return result[0], result[1] + + result = await self.db_pool.runInteraction( + "get_rooms_tombstones", _get_tombstones_txn, + ) + + return result + @cached(max_entries=10000) async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]: """Check if there are any overrides for ratelimiting for the given user @@ -1127,6 +1167,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return local_media_ids + def _quarantine_local_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media_ids: Set[str], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine local media items. 
+ + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of media IDs for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + # Effectively a legacy path, update any media that was explicitly named. + if media_ids: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "media_id", media_ids + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + # Note that a rowcount of -1 can be used to indicate no rows were affected. + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + # Update any media that was identified via hash. + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? 
+ WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + + def _quarantine_remote_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media: Set[Tuple[str, str]], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine remote items + + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of tuples (media_origin, media_id) for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + if media: + sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + media, + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE {sql_in_list_clause}""" + + txn.execute(sql, [quarantined_by] + sql_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + total_media_quarantined = 0 + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? 
+ WHERE {sql_many_clause_sql}""" + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + def _quarantine_media_txn( self, txn: LoggingTransaction, @@ -1146,40 +1289,93 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): Returns: The total number of media items quarantined """ - - # Update all the tables to set the quarantined_by flag - sql = """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """ - - # set quarantine - if quarantined_by is not None: - sql += "AND safe_from_quarantine = FALSE" - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hashes = set() + media_ids = set() + remote_media = set() + + # First, determine the hashes of the media we want to delete. + # We also want the media_ids for any media that lacks a hash. + if local_mxcs: + hash_sql_many_clause_sql, hash_sql_many_clause_args = ( + make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs) ) - # remove from quarantine - else: - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}" + if quarantined_by is not None: + hash_sql += " AND safe_from_quarantine = FALSE" + + txn.execute(hash_sql, hash_sql_many_clause_args) + for sha256, media_id in txn: + if sha256: + hashes.add(sha256) + else: + media_ids.add(media_id) + + # Do the same for remote media + if remote_mxcs: + hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + remote_mxcs, ) - # Note that a rowcount of -1 can be used to indicate no rows were affected. 
- total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 + hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}" + txn.execute(hash_sql, hash_sql_args) + for sha256, media_origin, media_id in txn: + if sha256: + hashes.add(sha256) + else: + remote_media.add((media_origin, media_id)) - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), + count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by) + count += self._quarantine_remote_media_txn( + txn, hashes, remote_media, quarantined_by ) - total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 - return total_media_quarantined + return count + + async def block_room(self, room_id: str, user_id: str) -> None: + """Marks the room as blocked. + + Can be called multiple times (though we'll only track the last user to + block this room). + + Can be called on a room unknown to this homeserver. + + Args: + room_id: Room to block + user_id: Who blocked it + """ + await self.db_pool.simple_upsert( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={"user_id": user_id}, + desc="block_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) + + async def unblock_room(self, room_id: str) -> None: + """Remove the room from blocking list. 
+ + Args: + room_id: Room to unblock + """ + await self.db_pool.simple_delete( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + desc="unblock_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False @@ -1382,6 +1578,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): partial_state_rooms = {row[0] for row in rows} return {room_id: room_id in partial_state_rooms for room_id in room_ids} + @cached(max_entries=10000, iterable=True) + async def get_partial_rooms(self) -> AbstractSet[str]: + """Get any "partial-state" rooms which the user is in. + + This is fast as the set of partially stated rooms at any point across + the whole server is small, and so such a query is fast. This is also + faster than looking up whether a set of room ID's are partially stated + via `is_partial_state_room_batched(...)` because of the sheer amount of + CPU time looking all the rooms up in the cache. 
+ """ + + def _get_partial_rooms_for_user_txn( + txn: LoggingTransaction, + ) -> AbstractSet[str]: + sql = """ + SELECT room_id FROM partial_state_rooms + """ + txn.execute(sql) + return {room_id for (room_id,) in txn} + + return await self.db_pool.runInteraction( + "get_partial_rooms_for_user", _get_partial_rooms_for_user_txn + ) + async def get_join_event_id_and_device_lists_stream_id_for_partial_state( self, room_id: str ) -> Tuple[str, int]: @@ -1562,6 +1782,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): direction: Direction = Direction.BACKWARDS, user_id: Optional[str] = None, room_id: Optional[str] = None, + event_sender_user_id: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: """Retrieve a paginated list of event reports @@ -1572,6 +1793,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): oldest first (forwards) user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None + event_sender_user_id: search for the sender of the reported event. 
Ignored if + event_sender_user_id is None Returns: Tuple of: json list of event reports @@ -1591,6 +1814,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): filters.append("er.room_id LIKE ?") args.extend(["%" + room_id + "%"]) + if event_sender_user_id: + filters.append("events.sender = ?") + args.extend([event_sender_user_id]) + if direction == Direction.BACKWARDS: order = "DESC" else: @@ -1606,11 +1833,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): sql = """ SELECT COUNT(*) as total_event_reports FROM event_reports AS er + LEFT JOIN events USING(event_id) JOIN room_stats_state ON room_stats_state.room_id = er.room_id {} - """.format( - where_clause - ) + """.format(where_clause) txn.execute(sql, args) count = cast(Tuple[int], txn.fetchone())[0] @@ -1626,8 +1852,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): room_stats_state.canonical_alias, room_stats_state.name FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id + LEFT JOIN events USING(event_id) JOIN room_stats_state ON room_stats_state.room_id = er.room_id {where_clause} @@ -2343,6 +2568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._invalidate_cache_and_stream( txn, self._get_partial_state_servers_at_join, (room_id,) ) + self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms) async def write_partial_state_rooms_join_event_id( self, @@ -2470,50 +2696,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) return next_id - async def block_room(self, room_id: str, user_id: str) -> None: - """Marks the room as blocked. - - Can be called multiple times (though we'll only track the last user to - block this room). - - Can be called on a room unknown to this homeserver. 
- - Args: - room_id: Room to block - user_id: Who blocked it - """ - await self.db_pool.simple_upsert( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - values={}, - insertion_values={"user_id": user_id}, - desc="block_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - - async def unblock_room(self, room_id: str) -> None: - """Remove the room from blocking list. - - Args: - room_id: Room to unblock - """ - await self.db_pool.simple_delete( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - desc="unblock_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - async def clear_partial_state_room(self, room_id: str) -> Optional[int]: """Clears the partial state flag for a room. @@ -2527,7 +2709,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): still contains events with partial state. """ try: - async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id: + async with ( + self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id + ): await self.db_pool.runInteraction( "clear_partial_state_room", self._clear_partial_state_room_txn, @@ -2564,6 +2748,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._invalidate_cache_and_stream( txn, self._get_partial_state_servers_at_join, (room_id,) ) + self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms) DatabasePool.simple_insert_txn( txn,