summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/room.py')
-rw-r--r--synapse/storage/databases/main/room.py226
1 files changed, 143 insertions, 83 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py

index 7d694d852d..c0e837854a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction -from synapse.storage.databases.main.search import SearchStore +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -38,9 +54,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RatelimitOverride: + messages_per_second: int + burst_count: int class RoomSortOrder(Enum): @@ -71,8 +88,13 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" -class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class RoomWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -83,7 +105,7 @@ class RoomWorkerStore(SQLBaseStore): room_creator_user_id: str, is_public: bool, room_version: RoomVersion, - ): + ) -> None: """Stores a room. Args: @@ -111,7 +133,7 @@ class RoomWorkerStore(SQLBaseStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> dict: + async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve a room. Args: @@ -136,7 +158,9 @@ class RoomWorkerStore(SQLBaseStore): A dict containing the room information, or None if the room is unknown. """ - def get_room_with_stats_txn(txn, room_id): + def get_room_with_stats_txn( + txn: LoggingTransaction, room_id: str + ) -> Optional[Dict[str, Any]]: sql = """ SELECT room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, @@ -185,7 +209,7 @@ class RoomWorkerStore(SQLBaseStore): ignore_non_federatable: If true filters out non-federatable rooms """ - def _count_public_rooms_txn(txn): + def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] if network_tuple: @@ -195,6 +219,7 @@ class RoomWorkerStore(SQLBaseStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -208,7 +233,7 @@ class RoomWorkerStore(SQLBaseStore): sql = """ SELECT - COALESCE(COUNT(*), 0) + COUNT(*) FROM ( %(published_sql)s ) published @@ -226,7 +251,7 @@ class RoomWorkerStore(SQLBaseStore): } txn.execute(sql, query_args) - return txn.fetchone()[0] + return cast(Tuple[int], txn.fetchone())[0] return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn @@ -235,11 +260,11 @@ class RoomWorkerStore(SQLBaseStore): async def get_room_count(self) -> int: """Retrieve the total number of rooms.""" - def f(txn): + def f(txn: LoggingTransaction) -> int: sql = "SELECT count(*) FROM rooms" txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 + row = cast(Tuple[int], txn.fetchone()) + return row[0] return await self.db_pool.runInteraction("get_rooms", f) @@ -251,7 +276,7 @@ class RoomWorkerStore(SQLBaseStore): bounds: Optional[Tuple[int, str]], forwards: bool, ignore_non_federatable: bool = False, - ): + ) -> List[Dict[str, Any]]: """Gets the largest public rooms (where largest is in terms of joined members, as tracked in the statistics table). @@ -272,7 +297,7 @@ class RoomWorkerStore(SQLBaseStore): """ where_clauses = [] - query_args = [] + query_args: List[Union[str, int]] = [] if network_tuple: if network_tuple.appservice_id: @@ -281,6 +306,7 @@ class RoomWorkerStore(SQLBaseStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -372,7 +398,9 @@ class RoomWorkerStore(SQLBaseStore): LIMIT ? """ - def _get_largest_public_rooms_txn(txn): + def _get_largest_public_rooms_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: txn.execute(sql, query_args) results = self.db_pool.cursor_to_dict(txn) @@ -435,7 +463,7 @@ class RoomWorkerStore(SQLBaseStore): """ # Filter room names by a string where_statement = "" - search_pattern = [] + search_pattern: List[object] = [] if search_term: where_statement = """ WHERE LOWER(state.name) LIKE ? @@ -543,7 +571,9 @@ class RoomWorkerStore(SQLBaseStore): where_statement, ) - def _get_rooms_paginate_txn(txn): + def _get_rooms_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: # Add the search term into the WHERE clause # and execute the data query txn.execute(info_sql, search_pattern + [limit, start]) @@ -575,7 +605,7 @@ class RoomWorkerStore(SQLBaseStore): # Add the search term into the WHERE clause if present txn.execute(count_sql, search_pattern) - room_count = txn.fetchone() + room_count = cast(Tuple[int], txn.fetchone()) return rooms, room_count[0] return await self.db_pool.runInteraction( @@ -620,7 +650,7 @@ class RoomWorkerStore(SQLBaseStore): burst_count: How many actions that can be performed before being limited. """ - def set_ratelimit_txn(txn): + def set_ratelimit_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_upsert_txn( txn, table="ratelimit_override", @@ -643,7 +673,7 @@ class RoomWorkerStore(SQLBaseStore): user_id: user ID of the user """ - def delete_ratelimit_txn(txn): + def delete_ratelimit_txn(txn: LoggingTransaction) -> None: row = self.db_pool.simple_select_one_txn( txn, table="ratelimit_override", @@ -667,7 +697,7 @@ class RoomWorkerStore(SQLBaseStore): await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn) @cached() - async def get_retention_policy_for_room(self, room_id): + async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]: """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -676,13 +706,15 @@ class RoomWorkerStore(SQLBaseStore): configuration). Args: - room_id (str): The ID of the room to get the retention policy of. + room_id: The ID of the room to get the retention policy of. Returns: - dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + A dict containing "min_lifetime" and "max_lifetime" for this room. """ - def get_retention_policy_for_room_txn(txn): + def get_retention_policy_for_room_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Optional[int]]]: txn.execute( """ SELECT min_lifetime, max_lifetime FROM room_retention @@ -707,19 +739,23 @@ class RoomWorkerStore(SQLBaseStore): "max_lifetime": self.config.retention.retention_default_max_lifetime, } - row = ret[0] + min_lifetime = ret[0]["min_lifetime"] + max_lifetime = ret[0]["max_lifetime"] # If one of the room's policy's attributes isn't defined, use the matching # attribute from the default policy. # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. - if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.retention.retention_default_min_lifetime + if min_lifetime is None: + min_lifetime = self.config.retention.retention_default_min_lifetime - if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.retention.retention_default_max_lifetime + if max_lifetime is None: + max_lifetime = self.config.retention.retention_default_max_lifetime - return row + return { + "min_lifetime": min_lifetime, + "max_lifetime": max_lifetime, + } async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room @@ -731,7 +767,9 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of the media IDs. """ - def _get_media_mxcs_in_room_txn(txn): + def _get_media_mxcs_in_room_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], List[str]]: local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) local_media_mxcs = [] remote_media_mxcs = [] @@ -757,7 +795,7 @@ class RoomWorkerStore(SQLBaseStore): logger.info("Quarantining media in room: %s", room_id) - def _quarantine_media_in_room_txn(txn): + def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int: local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) return self._quarantine_media_txn( txn, local_mxcs, remote_mxcs, quarantined_by @@ -767,13 +805,11 @@ class RoomWorkerStore(SQLBaseStore): "quarantine_media_in_room", _quarantine_media_in_room_txn ) - def _get_media_mxcs_in_room_txn(self, txn, room_id): + def _get_media_mxcs_in_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> Tuple[List[str], List[Tuple[str, str]]]: """Retrieves all the local and remote media MXC URIs in a given room - Args: - txn (cursor) - room_id (str) - Returns: The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. @@ -841,7 +877,7 @@ class RoomWorkerStore(SQLBaseStore): logger.info("Quarantining media: %s/%s", server_name, media_id) is_local = server_name == self.config.server.server_name - def _quarantine_media_by_id_txn(txn): + def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int: local_mxcs = [media_id] if is_local else [] remote_mxcs = [(server_name, media_id)] if not is_local else [] @@ -863,7 +899,7 @@ class RoomWorkerStore(SQLBaseStore): quarantined_by: The ID of the user who made the quarantine request """ - def _quarantine_media_by_user_txn(txn): + def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int: local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) @@ -871,7 +907,9 @@ class RoomWorkerStore(SQLBaseStore): "quarantine_media_by_user", _quarantine_media_by_user_txn ) - def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): + def _get_media_ids_by_user_txn( + self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True + ) -> List[str]: """Retrieves local media IDs by a given user Args: @@ -900,7 +938,7 @@ class RoomWorkerStore(SQLBaseStore): def _quarantine_media_txn( self, - txn, + txn: LoggingTransaction, local_mxcs: List[str], remote_mxcs: List[Tuple[str, str]], quarantined_by: Optional[str], @@ -928,12 +966,15 @@ class RoomWorkerStore(SQLBaseStore): # set quarantine if quarantined_by is not None: sql += "AND safe_from_quarantine = ?" - rows = [(quarantined_by, media_id, False) for media_id in local_mxcs] + txn.executemany( + sql, [(quarantined_by, media_id, False) for media_id in local_mxcs] + ) # remove from quarantine else: - rows = [(quarantined_by, media_id) for media_id in local_mxcs] + txn.executemany( + sql, [(quarantined_by, media_id) for media_id in local_mxcs] + ) - txn.executemany(sql, rows) # Note that a rowcount of -1 can be used to indicate no rows were affected. total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 @@ -951,7 +992,7 @@ class RoomWorkerStore(SQLBaseStore): async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False - ) -> Dict[str, dict]: + ) -> Dict[str, Dict[str, Optional[int]]]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. @@ -971,7 +1012,9 @@ class RoomWorkerStore(SQLBaseStore): "min_lifetime" (int|None), and "max_lifetime" (int|None). """ - def get_rooms_for_retention_period_in_range_txn(txn): + def get_rooms_for_retention_period_in_range_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, Optional[int]]]: range_conditions = [] args = [] @@ -1050,11 +1093,14 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) - self.config = hs.config - self.db_pool.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, @@ -1085,7 +1131,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_populate_rooms_creator_column, ) - async def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention( + self, progress: JsonDict, batch_size: int + ) -> int: """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -1095,7 +1143,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _background_insert_retention_txn(txn): + def _background_insert_retention_txn(txn: LoggingTransaction) -> bool: txn.execute( """ SELECT state.room_id, state.event_id, events.json @@ -1154,15 +1202,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _background_add_rooms_room_version_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add room version information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): + def _background_add_rooms_room_version_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM current_state_events INNER JOIN event_json USING (room_id, event_id) @@ -1223,7 +1273,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size async def _remove_tombstoned_rooms_from_directory( - self, progress, batch_size + self, progress: JsonDict, batch_size: int ) -> int: """Removes any rooms with tombstone events from the room directory @@ -1233,7 +1283,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): last_room = progress.get("room_id", "") - def _get_rooms(txn): + def _get_rooms(txn: LoggingTransaction) -> List[str]: txn.execute( """ SELECT room_id @@ -1271,7 +1321,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return len(rooms) @abstractmethod - def set_room_is_public(self, room_id, is_public): + def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]: # this will need to be implemented if a background update is performed with # existing (tombstoned, public) rooms in the database. # @@ -1318,7 +1368,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): 32-bit integer field. """ - def process(txn: Cursor) -> int: + def process(txn: LoggingTransaction) -> int: last_room = progress.get("last_room", "") txn.execute( """ @@ -1375,15 +1425,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return 0 async def _background_populate_rooms_creator_column( - self, progress: dict, batch_size: int - ): + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to go and add creator information to `rooms` table from `current_state_events` table. """ last_room_id = progress.get("room_id", "") - def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): + def _background_populate_rooms_creator_column_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT room_id, json FROM event_json INNER JOIN rooms AS room USING (room_id) @@ -1434,15 +1486,20 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size -class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) - self.config = hs.config + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] - ): + ) -> None: """Ensure that the room is stored in the table Called when we join a room over federation, and overwrites any room version @@ -1488,7 +1545,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion - ): + ) -> None: """ When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. @@ -1528,8 +1585,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( - self, room_id, appservice_id, network_id, is_public - ): + self, room_id: str, appservice_id: str, network_id: str, is_public: bool + ) -> None: """Edit the appservice/network specific public room list. Each appservice can have a number of published room lists associated @@ -1538,11 +1595,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): network. Args: - room_id (str) - appservice_id (str) - network_id (str) - is_public (bool): Whether to publish or unpublish the room from the - list. + room_id + appservice_id + network_id + is_public: Whether to publish or unpublish the room from the list. """ if is_public: @@ -1607,7 +1663,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): event_report: json list of information from event report """ - def _get_event_report_txn(txn, report_id): + def _get_event_report_txn( + txn: LoggingTransaction, report_id: int + ) -> Optional[Dict[str, Any]]: sql = """ SELECT @@ -1679,9 +1737,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): count: total number of event reports matching the filter criteria """ - def _get_event_reports_paginate_txn(txn): + def _get_event_reports_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: filters = [] - args = [] + args: List[object] = [] if user_id: filters.append("er.user_id LIKE ?") @@ -1705,7 +1765,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): where_clause ) txn.execute(sql, args) - count = txn.fetchone()[0] + count = cast(Tuple[int], txn.fetchone())[0] sql = """ SELECT