author    Patrick Cloke <patrickc@matrix.org>    2023-11-03 07:45:38 -0400
committer Patrick Cloke <patrickc@matrix.org>    2023-11-03 07:45:38 -0400
commit    671266b5a930674a26b25df8897957b05904dae9 (patch)
tree      74bd22f6bc1cb09b3bbca7c461d39c1d1880105b /synapse/storage/databases/main
parent    Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parent    Simplify event persistence code (#16584) (diff)
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--  synapse/storage/databases/main/__init__.py           52
-rw-r--r--  synapse/storage/databases/main/devices.py            55
-rw-r--r--  synapse/storage/databases/main/events.py            400
-rw-r--r--  synapse/storage/databases/main/media_repository.py   48
-rw-r--r--  synapse/storage/databases/main/registration.py       42
-rw-r--r--  synapse/storage/databases/main/room.py               82
6 files changed, 421 insertions, 258 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 840d725114..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,8 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
 
+import attr
+
 from synapse.api.constants import Direction
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage._base import make_in_list_sql_clause
@@ -28,7 +30,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import get_domain_from_id
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -82,6 +84,25 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPaginateResponse:
+    """This is very similar to UserInfo, but not quite the same."""
+
+    name: str
+    user_type: Optional[str]
+    is_guest: bool
+    admin: bool
+    deactivated: bool
+    shadow_banned: bool
+    displayname: Optional[str]
+    avatar_url: Optional[str]
+    creation_ts: Optional[int]
+    approved: bool
+    erased: bool
+    last_seen_ts: int
+    locked: bool
+
+
 class DataStore(
     EventsBackgroundUpdatesStore,
     ExperimentalFeaturesStore,
@@ -156,7 +177,7 @@ class DataStore(
         approved: bool = True,
         not_user_types: Optional[List[str]] = None,
         locked: bool = False,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[UserPaginateResponse], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -182,7 +203,7 @@ class DataStore(
 
         def get_users_paginate_txn(
            txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[List[UserPaginateResponse], int]:
             filters = []
             args: list = []
 
@@ -282,13 +303,24 @@ class DataStore(
             """
             args += [limit, start]
             txn.execute(sql, args)
-            users = self.db_pool.cursor_to_dict(txn)
-
-            # some of those boolean values are returned as integers
-            # when we're on SQLite
-            columns_to_boolify = ["erased"]
-            for user in users:
-                for column in columns_to_boolify:
-                    user[column] = bool(user[column])
+            users = [
+                UserPaginateResponse(
+                    name=row[0],
+                    user_type=row[1],
+                    is_guest=bool(row[2]),
+                    admin=bool(row[3]),
+                    deactivated=bool(row[4]),
+                    shadow_banned=bool(row[5]),
+                    displayname=row[6],
+                    avatar_url=row[7],
+                    creation_ts=row[8],
+                    approved=bool(row[9]),
+                    erased=bool(row[10]),
+                    last_seen_ts=row[11],
+                    locked=bool(row[12]),
+                )
+                for row in txn
+            ]
 
             return users, count
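The hunk above shows the recurring pattern in this merge: `cursor_to_dict` calls are replaced by reading raw tuples off the cursor into frozen attrs classes. A minimal sketch of the idea, using a hypothetical ExampleRow class (illustrative only, not Synapse code); note SQLite hands booleans back as 0/1 integers, hence the explicit bool() calls:

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class ExampleRow:
    name: str
    erased: bool

raw = ("@alice:example.com", 1)  # hypothetical driver output
row = ExampleRow(name=raw[0], erased=bool(raw[1]))
assert row.erased is True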
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index ae0536fbaf..303ef6ea27 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1622,7 +1622,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         #
         # For each duplicate, we delete all the existing rows and put one back.
 
-        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
         last_row = progress.get(
             "last_row",
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1630,44 +1629,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
-                [(x, last_row[x]) for x in KEY_COLS]
+                [
+                    ("stream_id", last_row["stream_id"]),
+                    ("destination", last_row["destination"]),
+                    ("user_id", last_row["user_id"]),
+                    ("device_id", last_row["device_id"]),
+                ]
             )
-            sql = """
+            sql = f"""
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
                 FROM device_lists_outbound_pokes
-                WHERE %s
-                GROUP BY %s
+                WHERE {clause}
+                GROUP BY stream_id, destination, user_id, device_id
                 HAVING count(*) > 1
-                ORDER BY %s
+                ORDER BY stream_id, destination, user_id, device_id
                 LIMIT ?
-            """ % (
-                clause,  # WHERE
-                ",".join(KEY_COLS),  # GROUP BY
-                ",".join(KEY_COLS),  # ORDER BY
-            )
+            """
             txn.execute(sql, args + [batch_size])
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
-            row = None
-            for row in rows:
+            stream_id, destination, user_id, device_id = None, None, None, None
+            for stream_id, destination, user_id, device_id, _ in rows:
                 self.db_pool.simple_delete_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    {x: row[x] for x in KEY_COLS},
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
                 )
 
-                row["sent"] = False
                 self.db_pool.simple_insert_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    row,
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "sent": False,
+                    },
                 )
 
-            if row:
+            if rows:
                 self.db_pool.updates._background_update_progress_txn(
                     txn,
                     BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
-                    {"last_row": row},
+                    {
+                        "last_row": {
+                            "stream_id": stream_id,
+                            "destination": destination,
+                            "user_id": user_id,
+                            "device_id": device_id,
+                        }
+                    },
                 )
 
             return len(rows)
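The `make_tuple_comparison_clause` call above builds a keyset-pagination predicate over the four-column key, so each batch resumes strictly after the last row processed. A hedged sketch of what such a clause looks like (the helper below is illustrative, not Synapse's implementation, whose generated SQL may differ):

from typing import List, Tuple

def tuple_comparison_clause(pairs: List[Tuple[str, object]]) -> Tuple[str, list]:
    # Builds "(a, b, ...) > (?, ?, ...)" plus the matching argument list.
    cols = ", ".join(col for col, _ in pairs)
    qs = ", ".join("?" for _ in pairs)
    return f"({cols}) > ({qs})", [val for _, val in pairs]

clause, args = tuple_comparison_clause(
    [("stream_id", 0), ("destination", ""), ("user_id", ""), ("device_id", "")]
)
# clause == "(stream_id, destination, user_id, device_id) > (?, ?, ?, ?)"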
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3c1492e3ad..647ba182f6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -79,7 +79,7 @@ class DeltaState:
     Attributes:
         to_delete: List of type/state_keys to delete from current state
         to_insert: Map of state to upsert into current state
-        no_longer_in_room: The server is not longer in the room, so the room
+        no_longer_in_room: The server is no longer in the room, so the room
             should e.g. be removed from `current_state_events` table.
     """
 
@@ -131,22 +131,25 @@ class PersistEventsStore:
     @trace
     async def _persist_events_and_state_updates(
         self,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
-        forward extremities tables.
+        forward extremities tables.
+
+        Assumes that we are only persisting events for one room at a time.
 
         Args:
+            room_id:
             events_and_contexts:
-            state_delta_for_room: Map from room_id to the delta to apply to
-                room state
-            new_forward_extremities: Map from room_id to set of event IDs
-                that are the new forward extremities of the room.
+            state_delta_for_room: The delta to apply to the room state
+            new_forward_extremities: A set of event IDs that are the new forward
+                extremities of the room.
             use_negative_stream_ordering: Whether to start stream_ordering on
                 the negative side and decrement. This should be set as True
                 for backfilled events because backfilled events get a negative
@@ -196,6 +199,7 @@ class PersistEventsStore:
             await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
+                room_id=room_id,
                 events_and_contexts=events_and_contexts,
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
@@ -221,9 +225,9 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, latest_event_ids in new_forward_extremities.items():
+            if new_forward_extremities:
                 self.store.get_latest_event_ids_in_room.prefill(
-                    (room_id,), frozenset(latest_event_ids)
+                    (room_id,), frozenset(new_forward_extremities)
                 )
 
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@@ -336,10 +340,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         *,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -347,8 +352,11 @@ class PersistEventsStore:
         and the rejections table. Things reading from those table will need to check
         whether the event was rejected.
 
+        Assumes that we are only persisting events for one room at a time.
+
         Args:
             txn
+            room_id: The room the events are from
             events_and_contexts: events to persist
             inhibit_local_membership_updates: Stop the local_current_membership
                 from being updated by these events. This should be set to True
@@ -357,10 +365,9 @@ class PersistEventsStore:
             delete_existing True to purge existing table rows for the
                 events from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room: The current-state delta for each room.
-            new_forward_extremities: The new forward extremities for each room.
-                For each room, a list of the event ids which are the forward
-                extremities.
+            state_delta_for_room: The current-state delta for the room.
+            new_forward_extremities: The new forward extremities for the room:
+                a set of the event ids which are the forward extremities.
 
         Raises:
             PartialStateConflictError: if attempting to persist a partial state event in
@@ -376,14 +383,13 @@ class PersistEventsStore:
         #
         # Annoyingly SQLite doesn't support row level locking.
         if isinstance(self.database_engine, PostgresEngine):
-            for room_id in {e.room_id for e, _ in events_and_contexts}:
-                txn.execute(
-                    "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
-                    (room_id,),
-                )
-                row = txn.fetchone()
-                if row is None:
-                    raise Exception(f"Room does not exist {room_id}")
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+                (room_id,),
+            )
+            row = txn.fetchone()
+            if row is None:
+                raise Exception(f"Room does not exist {room_id}")
 
         # stream orderings should have been assigned by now
         assert min_stream_order
@@ -419,7 +425,9 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+        self._update_room_depths_txn(
+            txn, room_id, events_and_contexts=events_and_contexts
+        )
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -432,11 +440,13 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremities,
-            max_stream_order=max_stream_order,
-        )
+        if new_forward_extremities:
+            self._update_forward_extremities_txn(
+                txn,
+                room_id,
+                new_forward_extremities=new_forward_extremities,
+                max_stream_order=max_stream_order,
+            )
 
         self._persist_transaction_ids_txn(txn, events_and_contexts)
 
@@ -464,7 +474,10 @@ class PersistEventsStore:
         # We call this last as it assumes we've inserted the events into
         # room_memberships, where applicable.
         # NB: This function invalidates all state related caches
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+        if state_delta_for_room:
+            self._update_current_state_txn(
+                txn, room_id, state_delta_for_room, min_stream_order
+            )
 
     def _persist_event_auth_chain_txn(
         self,
@@ -1026,74 +1039,75 @@ class PersistEventsStore:
         await self.db_pool.runInteraction(
             "update_current_state",
             self._update_current_state_txn,
-            state_delta_by_room={room_id: state_delta},
+            room_id,
+            delta_state=state_delta,
             stream_id=stream_ordering,
         )
 
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
-        state_delta_by_room: Dict[str, DeltaState],
+        room_id: str,
+        delta_state: DeltaState,
         stream_id: int,
     ) -> None:
-        for room_id, delta_state in state_delta_by_room.items():
-            to_delete = delta_state.to_delete
-            to_insert = delta_state.to_insert
-
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
+        to_delete = delta_state.to_delete
+        to_insert = delta_state.to_insert
+
+        # Figure out the changes of membership to invalidate the
+        # `get_rooms_for_user` cache.
+        # We find out which membership events we may have deleted
+        # and which we have added, then we invalidate the caches for all
+        # those users.
+        members_changed = {
+            state_key
+            for ev_type, state_key in itertools.chain(to_delete, to_insert)
+            if ev_type == EventTypes.Member
+        }
 
-            if delta_state.no_longer_in_room:
-                # Server is no longer in the room so we delete the room from
-                # current_state_events, being careful we've already updated the
-                # rooms.room_version column (which gets populated in a
-                # background task).
-                self._upsert_room_version_txn(txn, room_id)
+        if delta_state.no_longer_in_room:
+            # Server is no longer in the room so we delete the room from
+            # current_state_events, being careful we've already updated the
+            # rooms.room_version column (which gets populated in a
+            # background task).
+            self._upsert_room_version_txn(txn, room_id)
 
-                # Before deleting we populate the current_state_delta_stream
-                # so that async background tasks get told what happened.
-                sql = """
+            # Before deleting we populate the current_state_delta_stream
+            # so that async background tasks get told what happened.
+            sql = """
                     INSERT INTO current_state_delta_stream
                         (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, room_id, type, state_key, null, event_id
                         FROM current_state_events
                         WHERE room_id = ?
                 """
-                txn.execute(sql, (stream_id, self._instance_name, room_id))
+            txn.execute(sql, (stream_id, self._instance_name, room_id))
 
-                # We also want to invalidate the membership caches for users
-                # that were in the room.
-                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
-                members_changed.update(users_in_room)
+            # We also want to invalidate the membership caches for users
+            # that were in the room.
+            users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+            members_changed.update(users_in_room)
 
-                self.db_pool.simple_delete_txn(
-                    txn,
-                    table="current_state_events",
-                    keyvalues={"room_id": room_id},
-                )
-            else:
-                # We're still in the room, so we update the current state as normal.
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="current_state_events",
+                keyvalues={"room_id": room_id},
+            )
+        else:
+            # We're still in the room, so we update the current state as normal.
 
-                # First we add entries to the current_state_delta_stream. We
-                # do this before updating the current_state_events table so
-                # that we can use it to calculate the `prev_event_id`. (This
-                # allows us to not have to pull out the existing state
-                # unnecessarily).
-                #
-                # The stream_id for the update is chosen to be the minimum of the stream_ids
-                # for the batch of the events that we are persisting; that means we do not
-                # end up in a situation where workers see events before the
-                # current_state_delta updates.
-                #
-                sql = """
+            # First we add entries to the current_state_delta_stream. We
+            # do this before updating the current_state_events table so
+            # that we can use it to calculate the `prev_event_id`. (This
+            # allows us to not have to pull out the existing state
+            # unnecessarily).
+            #
+            # The stream_id for the update is chosen to be the minimum of the stream_ids
+            # for the batch of the events that we are persisting; that means we do not
+            # end up in a situation where workers see events before the
+            # current_state_delta updates.
+            #
+            sql = """
                 INSERT INTO current_state_delta_stream
                 (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                 SELECT ?, ?, ?, ?, ?, ?, (
                     SELECT event_id FROM current_state_events
@@ -1101,39 +1115,39 @@ class PersistEventsStore:
                     WHERE room_id = ? AND type = ?
                     AND state_key = ?
                 )
             """
-                txn.execute_batch(
-                    sql,
+            txn.execute_batch(
+                sql,
+                (
                     (
-                        (
-                            stream_id,
-                            self._instance_name,
-                            room_id,
-                            etype,
-                            state_key,
-                            to_insert.get((etype, state_key)),
-                            room_id,
-                            etype,
-                            state_key,
-                        )
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
-                # Now we actually update the current_state_events table
+                        stream_id,
+                        self._instance_name,
+                        room_id,
+                        etype,
+                        state_key,
+                        to_insert.get((etype, state_key)),
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
+            # Now we actually update the current_state_events table
 
-                txn.execute_batch(
-                    "DELETE FROM current_state_events"
-                    " WHERE room_id = ? AND type = ? AND state_key = ?",
-                    (
-                        (room_id, etype, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
+            txn.execute_batch(
+                "DELETE FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?",
+                (
+                    (room_id, etype, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
 
-                # We include the membership in the current state table, hence we do
-                # a lookup when we insert. This assumes that all events have already
-                # been inserted into room_memberships.
-                txn.execute_batch(
-                    """INSERT INTO current_state_events
+            # We include the membership in the current state table, hence we do
+            # a lookup when we insert. This assumes that all events have already
+            # been inserted into room_memberships.
+            txn.execute_batch(
+                """INSERT INTO current_state_events
                 (room_id, type, state_key, event_id, membership, event_stream_ordering)
             VALUES (
                 ?, ?, ?, ?,
@@ -1141,34 +1155,34 @@ class PersistEventsStore:
                 (SELECT stream_ordering FROM events WHERE event_id = ?)
             )
             """,
-                    [
-                        (room_id, key[0], key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                    ],
-                )
+                [
+                    (room_id, key[0], key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                ],
+            )
 
-            # We now update `local_current_membership`. We do this regardless
-            # of whether we're still in the room or not to handle the case where
-            # e.g. we just got banned (where we need to record that fact here).
-
-            # Note: Do we really want to delete rows here (that we do not
-            # subsequently reinsert below)? While technically correct it means
-            # we have no record of the fact the user *was* a member of the
-            # room but got, say, state reset out of it.
-            if to_delete or to_insert:
-                txn.execute_batch(
-                    "DELETE FROM local_current_membership"
-                    " WHERE room_id = ? AND user_id = ?",
-                    (
-                        (room_id, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                        if etype == EventTypes.Member and self.is_mine_id(state_key)
-                    ),
-                )
+        # We now update `local_current_membership`. We do this regardless
+        # of whether we're still in the room or not to handle the case where
+        # e.g. we just got banned (where we need to record that fact here).
+
+        # Note: Do we really want to delete rows here (that we do not
+        # subsequently reinsert below)? While technically correct it means
+        # we have no record of the fact the user *was* a member of the
+        # room but got, say, state reset out of it.
+        if to_delete or to_insert:
+            txn.execute_batch(
+                "DELETE FROM local_current_membership"
+                " WHERE room_id = ? AND user_id = ?",
+                (
+                    (room_id, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                    if etype == EventTypes.Member and self.is_mine_id(state_key)
+                ),
+            )
 
-            if to_insert:
-                txn.execute_batch(
-                    """INSERT INTO local_current_membership
+        if to_insert:
+            txn.execute_batch(
+                """INSERT INTO local_current_membership
                 (room_id, user_id, event_id, membership, event_stream_ordering)
             VALUES (
                 ?, ?, ?,
@@ -1176,29 +1190,27 @@ class PersistEventsStore:
                 (SELECT stream_ordering FROM events WHERE event_id = ?)
             )
             """,
-                    [
-                        (room_id, key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
-                    ],
-                )
-
-            txn.call_after(
-                self.store._curr_state_delta_stream_cache.entity_has_changed,
-                room_id,
-                stream_id,
-            )
+                [
+                    (room_id, key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                    if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                ],
+            )
 
-            # Invalidate the various caches
-            self.store._invalidate_state_caches_and_stream(
-                txn, room_id, members_changed
-            )
+        txn.call_after(
+            self.store._curr_state_delta_stream_cache.entity_has_changed,
+            room_id,
+            stream_id,
+        )
 
-            # Check if any of the remote membership changes requires us to
-            # unsubscribe from their device lists.
-            self.store.handle_potentially_left_users_txn(
-                txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
-            )
+        # Invalidate the various caches
+        self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+
+        # Check if any of the remote membership changes requires us to
+        # unsubscribe from their device lists.
+        self.store.handle_potentially_left_users_txn(
+            txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+        )
 
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
@@ -1232,23 +1244,19 @@ class PersistEventsStore:
     def _update_forward_extremities_txn(
         self,
         txn: LoggingTransaction,
-        new_forward_extremities: Dict[str, Set[str]],
+        room_id: str,
+        new_forward_extremities: Set[str],
         max_stream_order: int,
     ) -> None:
-        for room_id in new_forward_extremities.keys():
-            self.db_pool.simple_delete_txn(
-                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
-            )
+        self.db_pool.simple_delete_txn(
+            txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+        )
 
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             keys=("event_id", "room_id"),
-            values=[
-                (ev_id, room_id)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for ev_id in new_extrem
-            ],
+            values=[(ev_id, room_id) for ev_id in new_forward_extremities],
        )
 
         # We now insert into stream_ordering_to_exterm a mapping from room_id,
         # new stream_ordering to new forward extremeties in the room.
@@ -1260,8 +1268,7 @@ class PersistEventsStore:
             keys=("room_id", "event_id", "stream_ordering"),
             values=[
                 (room_id, event_id, max_stream_order)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for event_id in new_extrem
+                for event_id in new_forward_extremities
             ],
         )
 
@@ -1298,36 +1305,45 @@ class PersistEventsStore:
     def _update_room_depths_txn(
         self,
         txn: LoggingTransaction,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
     ) -> None:
         """Update min_depth for each room
 
         Args:
             txn: db connection
+            room_id: The room ID
            events_and_contexts: events we are persisting
         """
-        depth_updates: Dict[str, int] = {}
+        stream_ordering: Optional[int] = None
+        depth_update = 0
         for event, context in events_and_contexts:
-            # Then update the `stream_ordering` position to mark the latest
-            # event as the front of the room. This should not be done for
-            # backfilled events because backfilled events have negative
-            # stream_ordering and happened in the past so we know that we don't
-            # need to update the stream_ordering tip/front for the room.
+            # Don't update the stream ordering for backfilled events because
+            # backfilled events have negative stream_ordering and happened in the
+            # past, so we know that we don't need to update the stream_ordering
+            # tip/front for the room.
             assert event.internal_metadata.stream_ordering is not None
             if event.internal_metadata.stream_ordering >= 0:
-                txn.call_after(
-                    self.store._events_stream_cache.entity_has_changed,
-                    event.room_id,
-                    event.internal_metadata.stream_ordering,
-                )
+                if stream_ordering is None:
+                    stream_ordering = event.internal_metadata.stream_ordering
+                else:
+                    stream_ordering = max(
+                        stream_ordering, event.internal_metadata.stream_ordering
+                    )
 
             if not event.internal_metadata.is_outlier() and not context.rejected:
-                depth_updates[event.room_id] = max(
-                    event.depth, depth_updates.get(event.room_id, event.depth)
-                )
+                depth_update = max(event.depth, depth_update)
+
+        # Then update the `stream_ordering` position to mark the latest event as
+        # the front of the room.
+        if stream_ordering is not None:
+            txn.call_after(
+                self.store._events_stream_cache.entity_has_changed,
+                room_id,
+                stream_ordering,
+            )
 
-        for room_id, depth in depth_updates.items():
-            self._update_min_depth_for_room_txn(txn, room_id, depth)
+        self._update_min_depth_for_room_txn(txn, room_id, depth_update)
 
     def _update_outliers_txn(
         self,
@@ -1350,13 +1366,19 @@ class PersistEventsStore:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
-        txn.execute(
-            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
-            % (",".join(["?"] * len(events_and_contexts)),),
-            [event.event_id for event, _ in events_and_contexts],
+        rows = cast(
+            List[Tuple[str, bool]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                "events",
+                "event_id",
+                [event.event_id for event, _ in events_and_contexts],
+                keyvalues={},
+                retcols=("event_id", "outlier"),
+            ),
         )
 
-        have_persisted = dict(cast(Iterable[Tuple[str, bool]], txn))
+        have_persisted = dict(rows)
 
         logger.debug(
             "_update_outliers_txn: events=%s have_persisted=%s",
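One effect of threading room_id through `_persist_events_txn` is visible in the locking hunk above: a single `FOR SHARE` lock on the room row now replaces the old loop over every distinct room_id in the batch. A sketch of that step under stated assumptions (Synapse's LoggingTransaction translates the `?` paramstyle for the driver; `lock_room_for_persist` is a hypothetical name, not a function from the codebase):

from typing import Any

def lock_room_for_persist(txn: Any, room_id: str) -> None:
    # FOR SHARE takes a shared row lock on the rooms row, blocking a
    # concurrent deleter (which takes FOR UPDATE) until this transaction
    # commits, without blocking other readers or other persisters.
    txn.execute(
        "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
        (room_id,),
    )
    if txn.fetchone() is None:
        raise Exception(f"Room does not exist {room_id}")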
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index aeb3db596c..c8d7c9fd32 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,8 @@ from typing import (
     cast,
 )
 
+import attr
+
 from synapse.api.constants import Direction
 from synapse.logging.opentracing import trace
 from synapse.media._base import ThumbnailInfo
@@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
 )
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+    safe_from_quarantine: bool
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         user_id: str,
         order_by: str = MediaSortOrder.CREATED_TS.value,
         direction: Direction = Direction.FORWARDS,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[LocalMedia], int]:
         """Get a paginated list of metadata for a local piece of media
         which an user_id has uploaded
 
@@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
         def get_local_media_by_user_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Dict[str, Any]], int]:
+        ) -> Tuple[List[LocalMedia], int]:
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
 
@@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             sql = """
                 SELECT
-                    "media_id",
-                    "media_type",
-                    "media_length",
-                    "upload_name",
-                    "created_ts",
-                    "last_access_ts",
-                    "quarantined_by",
-                    "safe_from_quarantine"
+                    media_id,
+                    media_type,
+                    media_length,
+                    upload_name,
+                    created_ts,
+                    last_access_ts,
+                    quarantined_by,
+                    safe_from_quarantine
                 FROM local_media_repository
                 WHERE user_id = ?
                 ORDER BY {order_by_column} {order}, media_id ASC
@@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             args += [limit, start]
             txn.execute(sql, args)
-            media = self.db_pool.cursor_to_dict(txn)
+            media = [
+                LocalMedia(
+                    media_id=row[0],
+                    media_type=row[1],
+                    media_length=row[2],
+                    upload_name=row[3],
+                    created_ts=row[4],
+                    last_access_ts=row[5],
+                    quarantined_by=row[6],
+                    safe_from_quarantine=bool(row[7]),
+                )
+                for row in txn
+            ]
             return media, count
 
         return await self.db_pool.runInteraction(
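The `ORDER BY {order_by_column} {order}` interpolation above stays safe because `MediaSortOrder(order_by).value` whitelists the column name before it ever reaches the query string. A minimal sketch of the pattern with a hypothetical enum (illustrative, not the Synapse definitions):

from enum import Enum

class SortCol(Enum):
    CREATED_TS = "created_ts"
    MEDIA_LENGTH = "media_length"

def build_sql(order_by: str) -> str:
    # Enum lookup raises ValueError for anything not in the whitelist,
    # so only known column names can be interpolated.
    col = SortCol(order_by).value
    return f"SELECT media_id FROM local_media_repository ORDER BY {col}"

assert "created_ts" in build_sql("created_ts")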
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e09ab21593..933d76e905 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
     async def get_registration_tokens(
         self, valid: Optional[bool] = None
-    ) -> List[Dict[str, Any]]:
+    ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
         """List all registration tokens. Used by the admin API.
 
         Args:
@@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 Default is None: return all tokens regardless of validity.
 
         Returns:
-            A list of dicts, each containing details of a token.
+            A list of tuples containing:
+            * The token
+            * The number of users allowed (or None)
+            * Whether it is pending
+            * Whether it has been completed
+            * An expiry time (or None if no expiry)
         """
 
         def select_registration_tokens_txn(
             txn: LoggingTransaction, now: int, valid: Optional[bool]
-        ) -> List[Dict[str, Any]]:
+        ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
             if valid is None:
                 # Return all tokens regardless of validity
-                txn.execute("SELECT * FROM registration_tokens")
+                txn.execute(
+                    """
+                    SELECT token, uses_allowed, pending, completed, expiry_time
+                    FROM registration_tokens
+                    """
+                )
             elif valid:
                 # Select valid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
-                    "AND (expiry_time > ? OR expiry_time IS NULL)"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
+                AND (expiry_time > ? OR expiry_time IS NULL)
+                """
                 txn.execute(sql, [now])
             else:
                 # Select invalid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "uses_allowed <= pending + completed OR expiry_time <= ?"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE uses_allowed <= pending + completed OR expiry_time <= ?
+                """
                 txn.execute(sql, [now])
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(
+                List[Tuple[str, Optional[int], int, int, Optional[int]]],
+                txn.fetchall(),
+            )
 
         return await self.db_pool.runInteraction(
             "select_registration_tokens",
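The three SQL branches above encode a single validity rule for a token. Restated in Python for clarity (a sketch only; the names mirror the registration_tokens columns, and now is presumably the same epoch timestamp unit the expiry_time column uses):

from typing import Optional

def is_valid(uses_allowed: Optional[int], pending: int, completed: int,
             expiry_time: Optional[int], now: int) -> bool:
    # A token is valid while it still has unused slots and has not expired;
    # NULL uses_allowed means unlimited, NULL expiry_time means no expiry.
    has_uses = uses_allowed is None or uses_allowed > pending + completed
    unexpired = expiry_time is None or expiry_time > now
    return has_uses and unexpired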
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3e8fcf1975..6d4b9891e7 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride:
     burst_count: int
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LargestRoomStats:
+    room_id: str
+    name: Optional[str]
+    canonical_alias: Optional[str]
+    joined_members: int
+    join_rules: Optional[str]
+    guest_access: Optional[str]
+    history_visibility: Optional[str]
+    state_events: int
+    avatar: Optional[str]
+    topic: Optional[str]
+    room_type: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomStats(LargestRoomStats):
+    joined_local_members: int
+    version: Optional[str]
+    creator: Optional[str]
+    encryption: Optional[str]
+    federatable: bool
+    public: bool
+
+
 class RoomSortOrder(Enum):
     """
     Enum to define the sorting method used when returning rooms with get_rooms_paginate
@@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             allow_none=True,
         )
 
-    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
+    async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
         """Retrieve room with statistics.
 
         Args:
@@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def get_room_with_stats_txn(
             txn: LoggingTransaction, room_id: str
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[RoomStats]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 WHERE room_id = ?
                 """
             txn.execute(sql, [room_id])
-            # Catch error if sql returns empty result to return "None" instead of an error
-            try:
-                res = self.db_pool.cursor_to_dict(txn)[0]
-            except IndexError:
+            row = txn.fetchone()
+            if not row:
                 return None
-
-            res["federatable"] = bool(res["federatable"])
-            res["public"] = bool(res["public"])
-            return res
+            return RoomStats(
+                room_id=row[0],
+                name=row[1],
+                canonical_alias=row[2],
+                joined_members=row[3],
+                joined_local_members=row[4],
+                version=row[5],
+                creator=row[6],
+                encryption=row[7],
+                federatable=bool(row[8]),
+                public=bool(row[9]),
+                join_rules=row[10],
+                guest_access=row[11],
+                history_visibility=row[12],
+                state_events=row[13],
+                avatar=row[14],
+                topic=row[15],
+                room_type=row[16],
+            )
 
         return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
@@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ) -> List[Dict[str, Any]]:
+    ) -> List[LargestRoomStats]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
 
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def _get_largest_public_rooms_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Any]]:
+        ) -> List[LargestRoomStats]:
             txn.execute(sql, query_args)
 
-            results = self.db_pool.cursor_to_dict(txn)
+            results = [
+                LargestRoomStats(
+                    room_id=r[0],
+                    name=r[1],
+                    canonical_alias=r[3],
+                    joined_members=r[4],
+                    join_rules=r[8],
+                    guest_access=r[7],
+                    history_visibility=r[6],
+                    state_events=0,
+                    avatar=r[5],
+                    topic=r[2],
+                    room_type=r[9],
+                )
+                for r in txn
+            ]
 
             if not forwards:
                 results.reverse()
 
             return results
 
-        ret_val = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
-        return ret_val
 
     @cached(max_entries=10000)
     async def is_room_blocked(self, room_id: str) -> Optional[bool]:
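RoomStats subclassing LargestRoomStats above relies on attrs inheritance: the subclass's fields are appended after the base-class fields, which is why get_room_with_stats_txn constructs it with keywords rather than positionally. An illustrative sketch of the pattern (hypothetical classes, not the Synapse definitions):

import attr
from typing import Optional

@attr.s(slots=True, frozen=True, auto_attribs=True)
class Base:
    room_id: str
    name: Optional[str]

@attr.s(slots=True, frozen=True, auto_attribs=True)
class Full(Base):
    # Base fields come first in the generated __init__, then these.
    public: bool

stats = Full(room_id="!r:example.com", name=None, public=True)
assert stats.public is True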