diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 80a4bf95f2..347dbbba6b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
-from synapse.storage._base import db_to_json, make_in_list_sql_clause
+from synapse.storage._base import (
+ db_to_json,
+ make_in_list_sql_clause,
+)
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_tuple_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
@@ -73,6 +77,8 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RatelimitOverride:
+ # n.b. elsewhere in Synapse messages_per_second is represented as a float, but it is
+ # an integer in the database
messages_per_second: int
burst_count: int
@@ -604,6 +610,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
search_term: Optional[str],
public_rooms: Optional[bool],
empty_rooms: Optional[bool],
+ emma_include_tombstone: bool = False,
) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of rooms as json.
@@ -623,6 +630,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
If true, empty rooms are queried.
if false, empty rooms are excluded from the query. When it is
none (the default), both empty rooms and none-empty rooms are queried.
+ emma_include_tombstone: If true, include tombstone events in the results.
Returns:
A list of room dicts and an integer representing the total number of
rooms that exist given this query
@@ -791,11 +799,43 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
room_count = cast(Tuple[int], txn.fetchone())
return rooms, room_count[0]
- return await self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_rooms_paginate",
_get_rooms_paginate_txn,
)
+        if emma_include_tombstone and result[0]:
+ room_id_sql, room_id_args = make_in_list_sql_clause(
+ self.database_engine, "cse.room_id", [r["room_id"] for r in result[0]]
+ )
+
+ tombstone_sql = """
+ SELECT cse.room_id, cse.event_id, ej.json
+ FROM current_state_events cse
+ JOIN event_json ej USING (event_id)
+ WHERE cse.type = 'm.room.tombstone'
+ AND {room_id_sql}
+ """.format(
+ room_id_sql=room_id_sql
+ )
+
+ def _get_tombstones_txn(
+ txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
+ txn.execute(tombstone_sql, room_id_args)
+ for room_id, event_id, json in txn:
+ for result_room in result[0]:
+ if result_room["room_id"] == room_id:
+ result_room["gay.rory.synapse_admin_extensions.tombstone"] = db_to_json(json)
+ break
+ return result[0], result[1]
+
+ result = await self.db_pool.runInteraction(
+ "get_rooms_tombstones", _get_tombstones_txn,
+ )
+
+ return result
+
@cached(max_entries=10000)
async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
"""Check if there are any overrides for ratelimiting for the given user
@@ -1127,6 +1167,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return local_media_ids
+ def _quarantine_local_media_txn(
+ self,
+ txn: LoggingTransaction,
+ hashes: Set[str],
+ media_ids: Set[str],
+ quarantined_by: Optional[str],
+ ) -> int:
+ """Quarantine and unquarantine local media items.
+
+ Args:
+ txn (cursor)
+ hashes: A set of sha256 hashes for any media that should be quarantined
+ media_ids: A set of media IDs for any media that should be quarantined
+ quarantined_by: The ID of the user who initiated the quarantine request
+ If it is `None` media will be removed from quarantine
+ Returns:
+ The total number of media items quarantined
+ """
+ total_media_quarantined = 0
+
+ # Effectively a legacy path, update any media that was explicitly named.
+ if media_ids:
+ sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
+ txn.database_engine, "media_id", media_ids
+ )
+ sql = f"""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE {sql_many_clause_sql}"""
+
+ if quarantined_by is not None:
+ sql += " AND safe_from_quarantine = FALSE"
+
+ txn.execute(sql, [quarantined_by] + sql_many_clause_args)
+ # Note that a rowcount of -1 can be used to indicate no rows were affected.
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+ # Update any media that was identified via hash.
+ if hashes:
+ sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
+ txn.database_engine, "sha256", hashes
+ )
+ sql = f"""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE {sql_many_clause_sql}"""
+
+ if quarantined_by is not None:
+ sql += " AND safe_from_quarantine = FALSE"
+
+ txn.execute(sql, [quarantined_by] + sql_many_clause_args)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+ return total_media_quarantined
+
+ def _quarantine_remote_media_txn(
+ self,
+ txn: LoggingTransaction,
+ hashes: Set[str],
+ media: Set[Tuple[str, str]],
+ quarantined_by: Optional[str],
+ ) -> int:
+ """Quarantine and unquarantine remote items
+
+ Args:
+ txn (cursor)
+ hashes: A set of sha256 hashes for any media that should be quarantined
+            media: A set of tuples (media_origin, media_id) for any media that should be quarantined
+ quarantined_by: The ID of the user who initiated the quarantine request
+ If it is `None` media will be removed from quarantine
+ Returns:
+ The total number of media items quarantined
+ """
+ total_media_quarantined = 0
+
+ if media:
+ sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause(
+ txn.database_engine,
+ ("media_origin", "media_id"),
+ media,
+ )
+ sql = f"""
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE {sql_in_list_clause}"""
+
+ txn.execute(sql, [quarantined_by] + sql_args)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+        # n.b. do NOT reset total_media_quarantined here: it already holds the
+ if hashes:
+ sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause(
+ txn.database_engine, "sha256", hashes
+ )
+ sql = f"""
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE {sql_many_clause_sql}"""
+ txn.execute(sql, [quarantined_by] + sql_many_clause_args)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
+
+ return total_media_quarantined
+
def _quarantine_media_txn(
self,
txn: LoggingTransaction,
@@ -1146,40 +1289,93 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Returns:
The total number of media items quarantined
"""
-
- # Update all the tables to set the quarantined_by flag
- sql = """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """
-
- # set quarantine
- if quarantined_by is not None:
- sql += "AND safe_from_quarantine = FALSE"
- txn.executemany(
- sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+ hashes = set()
+ media_ids = set()
+ remote_media = set()
+
+ # First, determine the hashes of the media we want to delete.
+ # We also want the media_ids for any media that lacks a hash.
+ if local_mxcs:
+ hash_sql_many_clause_sql, hash_sql_many_clause_args = (
+ make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs)
)
- # remove from quarantine
- else:
- txn.executemany(
- sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+ hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}"
+ if quarantined_by is not None:
+ hash_sql += " AND safe_from_quarantine = FALSE"
+
+ txn.execute(hash_sql, hash_sql_many_clause_args)
+ for sha256, media_id in txn:
+ if sha256:
+ hashes.add(sha256)
+ else:
+ media_ids.add(media_id)
+
+ # Do the same for remote media
+ if remote_mxcs:
+ hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause(
+ txn.database_engine,
+ ("media_origin", "media_id"),
+ remote_mxcs,
)
- # Note that a rowcount of -1 can be used to indicate no rows were affected.
- total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
+ hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}"
+ txn.execute(hash_sql, hash_sql_args)
+ for sha256, media_origin, media_id in txn:
+ if sha256:
+ hashes.add(sha256)
+ else:
+ remote_media.add((media_origin, media_id))
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
+ count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by)
+ count += self._quarantine_remote_media_txn(
+ txn, hashes, remote_media, quarantined_by
)
- total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
- return total_media_quarantined
+ return count
+
+ async def block_room(self, room_id: str, user_id: str) -> None:
+ """Marks the room as blocked.
+
+ Can be called multiple times (though we'll only track the last user to
+ block this room).
+
+ Can be called on a room unknown to this homeserver.
+
+ Args:
+ room_id: Room to block
+ user_id: Who blocked it
+ """
+ await self.db_pool.simple_upsert(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={"user_id": user_id},
+ desc="block_room",
+ )
+ await self.db_pool.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
+
+ async def unblock_room(self, room_id: str) -> None:
+ """Remove the room from blocking list.
+
+ Args:
+ room_id: Room to unblock
+ """
+ await self.db_pool.simple_delete(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ desc="unblock_room",
+ )
+ await self.db_pool.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
@@ -1382,6 +1578,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
partial_state_rooms = {row[0] for row in rows}
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
+ @cached(max_entries=10000, iterable=True)
+ async def get_partial_rooms(self) -> AbstractSet[str]:
+ """Get any "partial-state" rooms which the user is in.
+
+ This is fast as the set of partially stated rooms at any point across
+ the whole server is small, and so such a query is fast. This is also
+ faster than looking up whether a set of room ID's are partially stated
+ via `is_partial_state_room_batched(...)` because of the sheer amount of
+ CPU time looking all the rooms up in the cache.
+ """
+
+ def _get_partial_rooms_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> AbstractSet[str]:
+ sql = """
+ SELECT room_id FROM partial_state_rooms
+ """
+ txn.execute(sql)
+ return {room_id for (room_id,) in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_partial_rooms_for_user", _get_partial_rooms_for_user_txn
+ )
+
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
@@ -1562,6 +1782,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
direction: Direction = Direction.BACKWARDS,
user_id: Optional[str] = None,
room_id: Optional[str] = None,
+ event_sender_user_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""Retrieve a paginated list of event reports
@@ -1572,6 +1793,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
oldest first (forwards)
user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None
+ event_sender_user_id: search for the sender of the reported event. Ignored if
+ event_sender_user_id is None
Returns:
Tuple of:
json list of event reports
@@ -1591,6 +1814,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
filters.append("er.room_id LIKE ?")
args.extend(["%" + room_id + "%"])
+ if event_sender_user_id:
+ filters.append("events.sender = ?")
+ args.extend([event_sender_user_id])
+
if direction == Direction.BACKWARDS:
order = "DESC"
else:
@@ -1606,11 +1833,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
sql = """
SELECT COUNT(*) as total_event_reports
FROM event_reports AS er
+ LEFT JOIN events USING(event_id)
JOIN room_stats_state ON room_stats_state.room_id = er.room_id
{}
- """.format(
- where_clause
- )
+ """.format(where_clause)
txn.execute(sql, args)
count = cast(Tuple[int], txn.fetchone())[0]
@@ -1626,8 +1852,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
room_stats_state.canonical_alias,
room_stats_state.name
FROM event_reports AS er
- LEFT JOIN events
- ON events.event_id = er.event_id
+ LEFT JOIN events USING(event_id)
JOIN room_stats_state
ON room_stats_state.room_id = er.room_id
{where_clause}
@@ -2343,6 +2568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._invalidate_cache_and_stream(
txn, self._get_partial_state_servers_at_join, (room_id,)
)
+ self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms)
async def write_partial_state_rooms_join_event_id(
self,
@@ -2470,50 +2696,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
return next_id
- async def block_room(self, room_id: str, user_id: str) -> None:
- """Marks the room as blocked.
-
- Can be called multiple times (though we'll only track the last user to
- block this room).
-
- Can be called on a room unknown to this homeserver.
-
- Args:
- room_id: Room to block
- user_id: Who blocked it
- """
- await self.db_pool.simple_upsert(
- table="blocked_rooms",
- keyvalues={"room_id": room_id},
- values={},
- insertion_values={"user_id": user_id},
- desc="block_room",
- )
- await self.db_pool.runInteraction(
- "block_room_invalidation",
- self._invalidate_cache_and_stream,
- self.is_room_blocked,
- (room_id,),
- )
-
- async def unblock_room(self, room_id: str) -> None:
- """Remove the room from blocking list.
-
- Args:
- room_id: Room to unblock
- """
- await self.db_pool.simple_delete(
- table="blocked_rooms",
- keyvalues={"room_id": room_id},
- desc="unblock_room",
- )
- await self.db_pool.runInteraction(
- "block_room_invalidation",
- self._invalidate_cache_and_stream,
- self.is_room_blocked,
- (room_id,),
- )
-
async def clear_partial_state_room(self, room_id: str) -> Optional[int]:
"""Clears the partial state flag for a room.
@@ -2527,7 +2709,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
still contains events with partial state.
"""
try:
- async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id:
+ async with (
+ self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id
+ ):
await self.db_pool.runInteraction(
"clear_partial_state_room",
self._clear_partial_state_room_txn,
@@ -2564,6 +2748,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._invalidate_cache_and_stream(
txn, self._get_partial_state_servers_at_join, (room_id,)
)
+ self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms)
DatabasePool.simple_insert_txn(
txn,
|