diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/controllers/persist_events.py | 30 | ||||
-rw-r--r-- | synapse/storage/controllers/state.py | 42 | ||||
-rw-r--r-- | synapse/storage/databases/main/account_data.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 9 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_push_actions.py | 303 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 38 | ||||
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 127 | ||||
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/registration.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/room.py | 6 | ||||
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 126 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 13 | ||||
-rw-r--r-- | synapse/storage/util/partial_state_events_tracker.py | 3 |
14 files changed, 387 insertions, 319 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index cf98b0ab48..dad3731b9b 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -45,8 +45,14 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.logging import opentracing from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.logging.opentracing import ( + SynapseTags, + active_span, + set_tag, + start_active_span_follows_from, + trace, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases @@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): queue.append(end_item) # also add our active opentracing span to the item so that we get a link back - span = opentracing.active_span() + span = active_span() if span: end_item.parent_opentracing_span_contexts.append(span.context) @@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): res = await make_deferred_yieldable(end_item.deferred.observe()) # add another opentracing span which links to the persist trace. - with opentracing.start_active_span_follows_from( + with start_active_span_follows_from( f"{task.name}_complete", (end_item.opentracing_span_context,) ): pass @@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): queue = self._get_drainining_queue(room_id) for item in queue: try: - with opentracing.start_active_span_follows_from( + with start_active_span_follows_from( item.task.name, item.parent_opentracing_span_contexts, inherit_force_tracing=True, @@ -355,7 +361,7 @@ class EventsPersistenceStorageController: f"Found an unexpected task type in event persistence queue: {task}" ) - @opentracing.trace + @trace async def persist_events( self, events_and_contexts: Iterable[Tuple[EventBase, EventContext]], @@ -380,9 +386,21 @@ class EventsPersistenceStorageController: PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ + event_ids: List[str] = [] partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) + event_ids.append(event.event_id) + + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids", + str(event_ids), + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled)) async def enqueue( item: Tuple[str, List[Tuple[EventBase, EventContext]]] @@ -418,7 +436,7 @@ class EventsPersistenceStorageController: self.main_store.get_room_max_token(), ) - @opentracing.trace + @trace async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0d480f1014..f9ffd0e29e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -29,7 +29,8 @@ from typing import ( from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import tag_args, trace +from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, @@ -228,10 +229,12 @@ class StateStorageController: return {event: event_to_state[event] for event in event_ids} @trace + @tag_args async def get_state_ids_for_events( self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -240,6 +243,9 @@ class StateStorageController: Args: event_ids: events whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at these events and `state_filter` is not satisfied by partial state. + Defaults to `True`. Returns: A dict from event_id -> (type, state_key) -> event_id @@ -248,8 +254,12 @@ class StateStorageController: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + if ( + await_full_state + and state_filter + and not state_filter.must_await_full_state(self._is_mine_id) + ): + # Full state is not required if the state filter is restrictive enough. await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -292,7 +302,10 @@ class StateStorageController: @trace async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -300,6 +313,9 @@ class StateStorageController: Args: event_id: event whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event and `state_filter` is not satisfied by partial state. + Defaults to `True`. Returns: A dict from (type, state_key) -> state_event_id @@ -309,7 +325,9 @@ class StateStorageController: outlier or is unknown) """ state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() + [event_id], + state_filter or StateFilter.all(), + await_full_state=await_full_state, ) return state_map[event_id] @@ -332,6 +350,7 @@ class StateStorageController: ) @trace + @tag_args async def get_state_group_for_events( self, event_ids: Collection[str], @@ -473,6 +492,7 @@ class StateStorageController: prev_stream_id, max_stream_id ) + @trace async def get_current_state( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[EventBase]: @@ -506,3 +526,15 @@ class StateStorageController: await self._partial_state_room_tracker.await_full_state(room_id) return await self.stores.main.get_current_hosts_in_room(room_id) + + async def get_users_in_room_with_profiles( + self, room_id: str + ) -> Dict[str, ProfileInfo]: + """ + Get the current users in the room with their profiles. + If the room is currently partial-stated, this will block until the room has + full state. + """ + await self._partial_state_room_tracker.await_full_state(room_id) + + return await self.stores.main.get_users_in_room_with_profiles(room_id) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 9af9f4f18e..c38b8a9e5a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn, self.get_account_data_for_room, (user_id,) ) self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,)) - self._invalidate_cache_and_stream( - txn, self.get_push_rules_enabled_for_user, (user_id,) - ) # This user might be contained in the ignored_by cache for other users, # so we have to invalidate it all. self._invalidate_all_cache_and_stream(txn, self.ignored_by) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index eec55b6478..c836078da6 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict +from synapse.logging.opentracing import tag_args, trace from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) return await self.get_events_as_list(event_ids) + @trace + @tag_args async def get_auth_chain_ids( self, room_id: str, @@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} + @trace + @tag_args async def get_oldest_event_ids_with_depth_in_room( self, room_id: str ) -> List[Tuple[str, int]]: @@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) + @trace async def get_insertion_event_backward_extremities_in_room( self, room_id: str ) -> List[Tuple[str, int]]: @@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_results.reverse() return event_results + @trace + @tag_args async def get_successor_events(self, event_id: str) -> List[str]: """Fetch all events that have the given event as a prev event @@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + @trace async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: await self.db_pool.simple_upsert( table="insertion_event_extremities", diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 161aad0f89..8dfa545c27 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -74,7 +74,17 @@ receipt. """ import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, + cast, +) import attr @@ -154,7 +164,9 @@ class NotifCounts: highlight_count: int = 0 -def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str: +def _serialize_action( + actions: Collection[Union[Mapping, str]], is_highlight: bool +) -> str: """Custom serializer for actions. This allows us to "compress" common actions. We use the fact that most users have the same actions for notifs (and for @@ -227,7 +239,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, ) -> NotifCounts: """Get the notification count, the highlight count and the unread message count - for a given user in a given room after the given read receipt. + for a given user in a given room after their latest read receipt. Note that this function assumes the user to be a current member of the room, since it's either called by the sync handler to handle joined room entries, or by @@ -238,9 +250,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to retrieve the counts for. Returns - A dict containing the counts mentioned earlier in this docstring, - respectively under the keys "notify_count", "highlight_count" and - "unread_count". + A NotifCounts object containing the notification count, the highlight count + and the unread message count. """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", @@ -255,6 +266,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas room_id: str, user_id: str, ) -> NotifCounts: + # Get the stream ordering of the user's latest receipt in the room. result = self.get_last_receipt_for_user_txn( txn, user_id, @@ -266,13 +278,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ), ) - stream_ordering = None if result: _, stream_ordering = result - if stream_ordering is None: - # Either last_read_event_id is None, or it's an event we don't have (e.g. - # because it's been purged), in which case retrieve the stream ordering for + else: + # If the user has no receipts in the room, retrieve the stream ordering for # the latest membership event from this user in this room (which we assume is # a join). event_id = self.db_pool.simple_select_one_onecol_txn( @@ -289,10 +299,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) def _get_unread_counts_by_pos_txn( - self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + self, + txn: LoggingTransaction, + room_id: str, + user_id: str, + receipt_stream_ordering: int, ) -> NotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. + + Args: + txn: The database transaction. + room_id: The room ID to get unread counts for. + user_id: The user ID to get unread counts for. + receipt_stream_ordering: The stream ordering of the user's latest + receipt in the room. If there are no receipts, the stream ordering + of the user's join event. + + Returns + A NotifCounts object containing the notification count, the highlight count + and the unread message count. """ counts = NotifCounts() @@ -320,7 +346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas OR last_receipt_stream_ordering = ? ) """, - (room_id, user_id, stream_ordering, stream_ordering), + (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), ) row = txn.fetchone() @@ -338,17 +364,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND stream_ordering > ? AND highlight = 1 """ - txn.execute(sql, (user_id, room_id, stream_ordering)) + txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) row = txn.fetchone() if row: counts.highlight_count += row[0] # Finally we need to count push actions that aren't included in the - # summary returned above, e.g. recent events that haven't been - # summarised yet, or the summary is empty due to a recent read receipt. - stream_ordering = max(stream_ordering, summary_stream_ordering) + # summary returned above. This might be due to recent events that haven't + # been summarised yet or the summary is out of date due to a recent read + # receipt. + start_unread_stream_ordering = max( + receipt_stream_ordering, summary_stream_ordering + ) notify_count, unread_count = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, stream_ordering + txn, room_id, user_id, start_unread_stream_ordering ) counts.notify_count += notify_count @@ -430,6 +459,32 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return await self.db_pool.runInteraction("get_push_action_users_in_range", f) + def _get_receipts_by_room_txn( + self, txn: LoggingTransaction, user_id: str + ) -> List[Tuple[str, int]]: + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" + SELECT room_id, MAX(stream_ordering) + FROM receipts_linearized + INNER JOIN events USING (room_id, event_id) + WHERE {receipt_types_clause} + AND user_id = ? + GROUP BY room_id + """ + + args.extend((user_id,)) + txn.execute(sql, args) + return cast(List[Tuple[str, int]], txn.fetchall()) + async def get_unread_push_actions_for_user_in_range_for_http( self, user_id: str, @@ -453,106 +508,45 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - # find rooms that have a read receipt in them and return the next - # push actions - def get_after_receipt( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool]]: - # find rooms that have a read receipt in them and return the next - # push actions - - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight - FROM ( - SELECT room_id, - MAX(stream_ordering) as stream_ordering - FROM events - INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) AS rl, - event_push_actions AS ep - WHERE - ep.room_id = rl.room_id - AND ep.stream_ordering > rl.stream_ordering - AND ep.user_id = ? - AND ep.stream_ordering > ? - AND ep.stream_ordering <= ? - AND ep.notif = 1 - ORDER BY ep.stream_ordering ASC LIMIT ? - """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) - return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) - - after_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + receipts_by_room = dict( + await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, + ), ) - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt( + def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool]]: - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight FROM event_push_actions AS ep - INNER JOIN events AS e USING (room_id, event_id) WHERE - ep.room_id NOT IN ( - SELECT room_id FROM receipts_linearized - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) - AND ep.user_id = ? + ep.user_id = ? AND ep.stream_ordering > ? AND ep.stream_ordering <= ? AND ep.notif = 1 ORDER BY ep.stream_ordering ASC LIMIT ? """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) + txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) - no_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + push_actions = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn ) notifs = [ HttpPushAction( - event_id=row[0], - room_id=row[1], - stream_ordering=row[2], - actions=_deserialize_action(row[3], row[4]), + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + actions=_deserialize_action(actions, highlight), ) - for row in after_read_receipt + no_read_receipt + for event_id, room_id, stream_ordering, actions, highlight in push_actions + # Only include push actions with a stream ordering after any receipt, or without any + # receipt present (invited to but never read rooms). + if stream_ordering > receipts_by_room.get(room_id, 0) ] # Now sort it so it's ordered correctly, since currently it will @@ -588,106 +582,49 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - # find rooms that have a read receipt in them and return the most recent - # push actions - def get_after_receipt( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool, int]]: - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight, e.received_ts - FROM ( - SELECT room_id, - MAX(stream_ordering) as stream_ordering - FROM events - INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) AS rl, - event_push_actions AS ep - INNER JOIN events AS e USING (room_id, event_id) - WHERE - ep.room_id = rl.room_id - AND ep.stream_ordering > rl.stream_ordering - AND ep.user_id = ? - AND ep.stream_ordering > ? - AND ep.stream_ordering <= ? - AND ep.notif = 1 - ORDER BY ep.stream_ordering DESC LIMIT ? - """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) - return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) - - after_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt + receipts_by_room = dict( + await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, + ), ) - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt( + def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - receipt_types_clause, args = make_in_list_sql_clause( - self.database_engine, - "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ), - ) - - sql = f""" + sql = """ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight, e.received_ts FROM event_push_actions AS ep INNER JOIN events AS e USING (room_id, event_id) WHERE - ep.room_id NOT IN ( - SELECT room_id FROM receipts_linearized - WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id - ) - AND ep.user_id = ? + ep.user_id = ? AND ep.stream_ordering > ? AND ep.stream_ordering <= ? AND ep.notif = 1 ORDER BY ep.stream_ordering DESC LIMIT ? """ - args.extend( - (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) - ) - txn.execute(sql, args) + txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) - no_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt + push_actions = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn ) # Make a list of dicts from the two sets of results. notifs = [ EmailPushAction( - event_id=row[0], - room_id=row[1], - stream_ordering=row[2], - actions=_deserialize_action(row[3], row[4]), - received_ts=row[5], + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + actions=_deserialize_action(actions, highlight), + received_ts=received_ts, ) - for row in after_read_receipt + no_read_receipt + for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions + # Only include push actions with a stream ordering after any receipt, or without any + # receipt present (invited to but never read rooms). + if stream_ordering > receipts_by_room.get(room_id, 0) ] # Now sort it so it's ordered correctly, since currently it will @@ -733,7 +670,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas async def add_push_actions_to_staging( self, event_id: str, - user_id_actions: Dict[str, List[Union[dict, str]]], + user_id_actions: Dict[str, Collection[Union[Mapping, str]]], count_as_unread: bool, ) -> None: """Add the push actions for the event to the push action staging area. @@ -750,7 +687,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # This is a helper function for generating the necessary tuple that # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( - user_id: str, actions: List[Union[dict, str]] + user_id: str, actions: Collection[Union[Mapping, str]] ) -> Tuple[str, str, str, int, int, int]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 @@ -1151,8 +1088,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: The database transaction. old_rotate_stream_ordering: The previous maximum event stream ordering. rotate_to_stream_ordering: The new maximum event stream ordering to summarise. - - Returns whether the archiving process has caught up or not. """ # Calculate the new counts that should be upserted into event_push_summary @@ -1238,9 +1173,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (rotate_to_stream_ordering,), ) - async def _remove_old_push_actions_that_have_rotated( - self, - ) -> None: + async def _remove_old_push_actions_that_have_rotated(self) -> None: """Clear out old push actions that have been summarised.""" # We want to clear out anything that is older than a day that *has* already @@ -1397,7 +1330,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ] -def _action_has_highlight(actions: List[Union[dict, str]]) -> bool: +def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool: for action in actions: if not isinstance(action, dict): continue diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5560b38a48..a4010ee28d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext +from synapse.logging.opentracing import trace from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -145,6 +146,7 @@ class PersistEventsStore: self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen + @trace async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index b07d812ae2..8a7cdb024d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -54,6 +54,7 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) +from synapse.logging.opentracing import start_active_span, tag_args, trace from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} + @trace + @tag_args async def get_events_as_list( self, event_ids: Collection[str], @@ -1090,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore): """ fetched_event_ids: Set[str] = set() fetched_events: Dict[str, _EventRow] = {} - events_to_fetch = event_ids - while events_to_fetch: - row_map = await self._enqueue_events(events_to_fetch) + async def _fetch_event_ids_and_get_outstanding_redactions( + event_ids_to_fetch: Collection[str], + ) -> Collection[str]: + """ + Fetch all of the given event_ids and return any associated redaction event_ids + that we still need to fetch in the next iteration. + """ + row_map = await self._enqueue_events(event_ids_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids: Set[str] = set() - for event_id in events_to_fetch: + for event_id in event_ids_to_fetch: row = row_map.get(event_id) fetched_event_ids.add(event_id) if row: fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_event_ids) - if events_to_fetch: - logger.debug("Also fetching redaction events %s", events_to_fetch) + event_ids_to_fetch = redaction_ids.difference(fetched_event_ids) + return event_ids_to_fetch + + # Grab the initial list of events requested + event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions( + event_ids + ) + # Then go and recursively find all of the associated redactions + with start_active_span("recursively fetching redactions"): + while event_ids_to_fetch: + logger.debug("Also fetching redaction events %s", event_ids_to_fetch) + + event_ids_to_fetch = ( + await _fetch_event_ids_and_get_outstanding_redactions( + event_ids_to_fetch + ) + ) # build a map from event_id to EventBase event_map: Dict[str, EventBase] = {} @@ -1424,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} + @trace + @tag_args async def have_seen_events( self, room_id: str, event_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 768f95d16c..5079edd1e0 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -14,11 +14,23 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.push.baserules import list_with_base_rules +from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -50,60 +62,30 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _is_experimental_rule_enabled( - rule_id: str, experimental_config: ExperimentalConfig -) -> bool: - """Used by `_load_rules` to filter out experimental rules when they - have not been enabled. - """ - if ( - rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl" - and not experimental_config.msc3786_enabled - ): - return False - if ( - rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - and not experimental_config.msc3772_enabled - ): - return False - return True - - def _load_rules( rawrules: List[JsonDict], enabled_map: Dict[str, bool], experimental_config: ExperimentalConfig, -) -> List[JsonDict]: - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = db_to_json(rawrule["conditions"]) - rule["actions"] = db_to_json(rawrule["actions"]) - rule["default"] = False - ruleslist.append(rule) - - # We're going to be mutating this a lot, so copy it. We also filter out - # any experimental default push rules that aren't enabled. - rules = [ - rule - for rule in list_with_base_rules(ruleslist) - if _is_experimental_rule_enabled(rule["rule_id"], experimental_config) - ] +) -> FilteredPushRules: + """Take the DB rows returned from the DB and convert them into a full + `FilteredPushRules` object. + """ - for i, rule in enumerate(rules): - rule_id = rule["rule_id"] + ruleslist = [ + PushRule( + rule_id=rawrule["rule_id"], + priority_class=rawrule["priority_class"], + conditions=db_to_json(rawrule["conditions"]), + actions=db_to_json(rawrule["actions"]), + ) + for rawrule in rawrules + ] - if rule_id not in enabled_map: - continue - if rule.get("enabled", True) == bool(enabled_map[rule_id]): - continue + push_rules = compile_push_rules(ruleslist) - # Rules are cached across users. - rule = dict(rule) - rule["enabled"] = bool(enabled_map[rule_id]) - rules[i] = rule + filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config) - return rules + return filtered_rules # The ABCMeta metaclass ensures that it cannot be instantiated without @@ -162,7 +144,7 @@ class PushRulesWorkerStore( raise NotImplementedError() @cached(max_entries=5000) - async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]: + async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, @@ -183,7 +165,6 @@ class PushRulesWorkerStore( return _load_rules(rows, enabled_map, self.hs.config.experimental) - @cached(max_entries=5000) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( table="push_rules_enable", @@ -216,11 +197,11 @@ class PushRulesWorkerStore( @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str] - ) -> Dict[str, List[JsonDict]]: + ) -> Dict[str, FilteredPushRules]: if not user_ids: return {} - results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} + raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules", @@ -234,20 +215,19 @@ class PushRulesWorkerStore( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: - results.setdefault(row["user_name"], []).append(row) + raw_rules.setdefault(row["user_name"], []).append(row) enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) - for user_id, rules in results.items(): + results: Dict[str, FilteredPushRules] = {} + + for user_id, rules in raw_rules.items(): results[user_id] = _load_rules( rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental ) return results - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids" - ) async def bulk_get_push_rules_enabled( self, user_ids: Collection[str] ) -> Dict[str, Dict[str, bool]]: @@ -262,6 +242,7 @@ class PushRulesWorkerStore( iterable=user_ids, retcols=("user_name", "rule_id", "enabled"), desc="bulk_get_push_rules_enabled", + batch_size=1000, ) for row in rows: enabled = bool(row["enabled"]) @@ -345,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore): user_id: str, rule_id: str, priority_class: int, - conditions: List[Dict[str, str]], - actions: List[Union[JsonDict, str]], + conditions: Sequence[Mapping[str, str]], + actions: Sequence[Union[Mapping[str, Any], str]], before: Optional[str] = None, after: Optional[str] = None, ) -> None: @@ -808,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore): self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) - txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) txn.call_after( self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) @@ -817,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore): return self._push_rules_stream_id_gen.get_current_token() async def copy_push_rule_from_room_to_room( - self, new_room_id: str, user_id: str, rule: dict + self, new_room_id: str, user_id: str, rule: PushRule ) -> None: """Copy a single push rule from one room to another for a specific user. @@ -827,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore): rule: A push rule. """ # Create new rule id - rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) + rule_id_scope = "/".join(rule.rule_id.split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id + new_conditions = [] + # Change room id in each condition - for condition in rule.get("conditions", []): + for condition in rule.conditions: + new_condition = condition if condition.get("key") == "room_id": - condition["pattern"] = new_room_id + new_condition = dict(condition) + new_condition["pattern"] = new_room_id + + new_conditions.append(new_condition) # Add the rule for the new room await self.add_push_rule( user_id=user_id, rule_id=new_rule_id, - priority_class=rule["priority_class"], - conditions=rule["conditions"], - actions=rule["actions"], + priority_class=rule.priority_class, + conditions=new_conditions, + actions=rule.actions, ) async def copy_push_rules_from_room_to_room_for_user( @@ -859,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore): user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room - for rule in user_push_rules: - conditions = rule.get("conditions", []) + for rule, enabled in user_push_rules: + if not enabled: + continue + + conditions = rule.conditions if any( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 0090c9f225..124c70ad37 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: The receipt types to fetch. Returns: - The latest receipt, if one exists. + The event ID and stream ordering of the latest receipt, if one exists. """ clause, args = make_in_list_sql_clause( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index cb63cd9b7d..7fb9c801da 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -69,9 +69,9 @@ class TokenLookupResult: """ user_id: str + token_id: int is_guest: bool = False shadow_banned: bool = False - token_id: Optional[int] = None device_id: Optional[str] = None valid_until_ms: Optional[int] = None token_owner: str = attr.ib() diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 0f1f0d11ea..b7d4baa6bb 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + # We join on room_stats_state despite not using any columns from it + # because the join can influence the number of rows returned; + # e.g. a room that doesn't have state, maybe because it was deleted. + # The query returning the total count should be consistent with + # the query returning the results. sql = """ SELECT COUNT(*) as total_event_reports FROM event_reports AS er + JOIN room_stats_state ON room_stats_state.room_id = er.room_id {} """.format( where_clause diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 93ff4816c8..9e5034b401 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -283,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns: A mapping from user ID to ProfileInfo. + + Preconditions: + - There is full state available for the room (it is not partial-stated). """ def _get_users_in_room_with_profiles( @@ -531,6 +534,32 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_local_users_in_room", ) + async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool: + """ + Check whether a given local user is currently joined to the given room. + + Returns: + A boolean indicating whether the user is currently joined to the room + + Raises: + Exeption when called with a non-local user to this homeserver + """ + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'check_local_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + ( + membership, + member_event_id, + ) = await self.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + + return membership == Membership.JOIN + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: @@ -832,9 +861,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return shared_room_ids or frozenset() - async def get_joined_users_from_state( + async def get_joined_user_ids_from_state( self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -845,25 +874,25 @@ class RoomMemberWorkerStore(EventsWorkerStore): assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): - return await self._get_joined_users_from_context( + return await self._get_joined_user_ids_from_context( room_id, state_group, state, context=state_entry ) @cached(num_args=2, iterable=True, max_entries=100000) - async def _get_joined_users_from_context( + async def _get_joined_user_ids_from_context( self, room_id: str, state_group: Union[object, int], current_state_ids: StateMap[str], event: Optional[EventBase] = None, context: Optional["_StateCacheEntry"] = None, - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. assert state_group is not None - users_in_room = {} + users_in_room = set() member_event_ids = [ e_id for key, e_id in current_state_ids.items() @@ -876,11 +905,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get_immediate( + prev_res = self._get_joined_user_ids_from_context.cache.get_immediate( (room_id, context.prev_group), None ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) + if prev_res and isinstance(prev_res, set): + users_in_room = prev_res member_event_ids = [ e_id for key, e_id in context.delta_ids.items() @@ -888,7 +917,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ] for etype, state_key in context.delta_ids: if etype == EventTypes.Member: - users_in_room.pop(state_key, None) + users_in_room.discard(state_key) # We check if we have any of the member event ids in the event cache # before we ask the DB @@ -905,71 +934,64 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: - users_in_room[ev_entry.event.state_key] = ProfileInfo( - display_name=ev_entry.event.content.get("displayname", None), - avatar_url=ev_entry.event.content.get("avatar_url", None), - ) + users_in_room.add(ev_entry.event.state_key) else: missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = await self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_user_ids_from_membership_event_ids( missing_member_event_ids ) - users_in_room.update(row for row in event_to_memberships.values() if row) + users_in_room.update( + user_id for user_id in event_to_memberships.values() if user_id + ) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[event.state_key] = ProfileInfo( - display_name=event.content.get("displayname", None), - avatar_url=event.content.get("avatar_url", None), - ) + users_in_room.add(event.state_key) return users_in_room - @cached(max_entries=10000) - def _get_joined_profile_from_event_id( + @cached( + max_entries=10000, + # This name matches the old function that has been replaced - the cache name + # is kept here to maintain backwards compatibility. + name="_get_joined_profile_from_event_id", + ) + def _get_user_id_from_membership_event_id( self, event_id: str ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", + cached_method_name="_get_user_id_from_membership_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids( + async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: + ) -> Dict[str, Optional[str]]: """For given set of member event_ids check if they point to a join - event and if so return the associated user and profile info. + event. Args: event_ids: The member event IDs to lookup Returns: - Map from event ID to `user_id` and ProfileInfo (or None if not join event). + Map from event ID to `user_id`, or None if event is not a join. """ rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), + retcols=("user_id", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=1000, - desc="_get_joined_profiles_from_event_ids", + desc="_get_user_ids_from_membership_event_ids", ) - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } + return {row["event_id"]: row["user_id"] for row in rows} @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: @@ -1128,12 +1150,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): else: # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. - joined_users = await self.get_joined_users_from_state( + joined_user_ids = await self.get_joined_user_ids_from_state( room_id, state, state_entry ) cache.hosts_to_joined_users = {} - for user_id in joined_users: + for user_id in joined_user_ids: host = intern_string(get_domain_from_id(user_id)) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) @@ -1212,6 +1234,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) + async def is_locally_forgotten_room(self, room_id: str) -> bool: + """Returns whether all local users have forgotten this room_id. + + Args: + room_id: The room ID to query. + + Returns: + Whether the room is forgotten. + """ + + sql = """ + SELECT count(*) > 0 FROM local_current_membership + INNER JOIN room_memberships USING (room_id, event_id) + WHERE + room_id = ? + AND forgotten = 0; + """ + + rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) + + # `count(*)` returns always an integer + # If any rows still exist it means someone has not forgotten this room yet + return not rows[0][0] + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 3c13859faa..2dfe4c0b66 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -460,8 +460,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # Cast safety: this corresponds to the types returned by the query above. rows.extend(cast(Iterable[Tuple[str, int]], cur)) - # Sort so that we handle rows in order for each instance. - rows.sort() + # Sort by stream_id (ascending, lowest -> highest) so that we handle + # rows in order for each instance because we don't want to overwrite + # the current_position of an instance to a lower stream ID than + # we're actually at. + def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int: + (instance, stream_id) = row + # If `stream_id` is ever `None`, we will see a `TypeError: '<' + # not supported between instances of 'NoneType' and 'X'` error. + return stream_id + + rows.sort(key=sort_by_stream_id_key_func) with self._lock: for ( diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index 466e5137f2..b4bf49dace 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -20,6 +20,7 @@ from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.logging.opentracing import trace_with_opname from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError @@ -58,6 +59,7 @@ class PartialStateEventsTracker: for o in observers: o.callback(None) + @trace_with_opname("PartialStateEventsTracker.await_full_state") async def await_full_state(self, event_ids: Collection[str]) -> None: """Wait for all the given events to have full state. @@ -151,6 +153,7 @@ class PartialCurrentStateTracker: for o in observers: o.callback(None) + @trace_with_opname("PartialCurrentStateTracker.await_full_state") async def await_full_state(self, room_id: str) -> None: # We add the deferred immediately so that the DB call to check for # partial state doesn't race when we unpartial the room. |