summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-01-12 12:20:17 +0000
committerErik Johnston <erik@matrix.org>2024-01-12 12:20:17 +0000
commit4df836af09cb0424757852b286a38b24543a3286 (patch)
tree8ccbcb4b2c0b0c43dfb838e2ffea3d0cddf502ea /synapse
parentMerge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parentUpdate license in Debian metadata (#16807) (diff)
downloadsynapse-matrix-org-hotfixes.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/oidc.py15
-rw-r--r--synapse/handlers/federation_event.py100
-rw-r--r--synapse/handlers/room_summary.py12
-rw-r--r--synapse/handlers/sync.py14
-rw-r--r--synapse/push/push_tools.py8
-rw-r--r--synapse/storage/databases/main/deviceinbox.py149
-rw-r--r--synapse/storage/databases/main/devices.py8
-rw-r--r--synapse/storage/databases/main/event_push_actions.py253
-rw-r--r--synapse/storage/databases/main/state.py63
-rw-r--r--synapse/types/state.py24
10 files changed, 385 insertions, 261 deletions
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py

index 07ca16c94c..8f9cdbddbb 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py
@@ -299,6 +299,19 @@ def _parse_oidc_config_dict( config_path + ("client_secret",), ) + # If no client secret is specified then the auth method must be None + client_auth_method = oidc_config.get("client_auth_method") + if client_secret is None and client_secret_jwt_key is None: + if client_auth_method is None: + client_auth_method = "none" + elif client_auth_method != "none": + raise ConfigError( + "No 'client_secret' is set in OIDC config, and 'client_auth_method' is not set to 'none'" + ) + + if client_auth_method is None: + client_auth_method = "client_secret_basic" + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), @@ -309,7 +322,7 @@ def _parse_oidc_config_dict( client_id=oidc_config["client_id"], client_secret=client_secret, client_secret_jwt_key=client_secret_jwt_key, - client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + client_auth_method=client_auth_method, pkce_method=oidc_config.get("pkce_method", "auto"), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 882be905db..12837429b9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py
@@ -94,7 +94,7 @@ from synapse.types import ( ) from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.iterutils import batch_iter, partition, sorted_topologically_batched +from synapse.util.iterutils import batch_iter, partition, sorted_topologically from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -1141,16 +1141,8 @@ class FederationEventHandler: partial_state_flags = await self._store.get_partial_state_events(seen) partial_state = any(partial_state_flags.values()) - # Get the state of the events we know about - ours = await self._state_storage_controller.get_state_groups_ids( - room_id, seen, await_full_state=False - ) - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours + state_maps: List[StateMap[str]] = [] # Ask the remote server for the states we don't # know about @@ -1169,6 +1161,17 @@ class FederationEventHandler: state_maps.append(remote_state_map) + # Get the state of the events we know about. We do this *after* + # trying to fetch missing state over federation as that might fail + # and then we can skip loading the local state. + ours = await self._state_storage_controller.get_state_groups_ids( + room_id, seen, await_full_state=False + ) + state_maps.extend(ours.values()) + + # we don't need this any more, let's delete it. + del ours + room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, @@ -1678,57 +1681,36 @@ class FederationEventHandler: # We need to persist an event's auth events before the event. auth_graph = { - ev: [event_map[e_id] for e_id in ev.auth_event_ids() if e_id in event_map] + ev.event_id: [e_id for e_id in ev.auth_event_ids() if e_id in event_map] for ev in event_map.values() } - for roots in sorted_topologically_batched(event_map.values(), auth_graph): - if not roots: - # if *none* of the remaining events are ready, that means - # we have a loop. This either means a bug in our logic, or that - # somebody has managed to create a loop (which requires finding a - # hash collision in room v2 and later). - logger.warning( - "Loop found in auth events while fetching missing state/auth " - "events: %s", - shortstr(event_map.keys()), - ) - return - - logger.info( - "Persisting %i of %i remaining outliers: %s", - len(roots), - len(event_map), - shortstr(e.event_id for e in roots), - ) - - await self._auth_and_persist_outliers_inner(room_id, roots) - - async def _auth_and_persist_outliers_inner( - self, room_id: str, fetched_events: Collection[EventBase] - ) -> None: - """Helper for _auth_and_persist_outliers - - Persists a batch of events where we have (theoretically) already persisted all - of their auth events. - - Marks the events as outliers, auths them, persists them to the database, and, - where appropriate (eg, an invite), awakes the notifier. + sorted_auth_event_ids = sorted_topologically(event_map.keys(), auth_graph) + sorted_auth_events = [event_map[e_id] for e_id in sorted_auth_event_ids] + logger.info( + "Persisting %i remaining outliers: %s", + len(sorted_auth_events), + shortstr(e.event_id for e in sorted_auth_events), + ) - Params: - origin: where the events came from - room_id: the room that the events are meant to be in (though this has - not yet been checked) - fetched_events: the events to persist - """ # get all the auth events for all the events in this batch. By now, they should # have been persisted. - auth_events = { - aid for event in fetched_events for aid in event.auth_event_ids() + auth_event_ids = { + aid for event in sorted_auth_events for aid in event.auth_event_ids() } - persisted_events = await self._store.get_events( - auth_events, - allow_rejected=True, - ) + auth_map = { + ev.event_id: ev + for ev in sorted_auth_events + if ev.event_id in auth_event_ids + } + + missing_events = auth_event_ids.difference(auth_map) + if missing_events: + persisted_events = await self._store.get_events( + missing_events, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) + auth_map.update(persisted_events) events_and_contexts_to_persist: List[Tuple[EventBase, EventContext]] = [] @@ -1736,7 +1718,7 @@ class FederationEventHandler: with nested_logging_context(suffix=event.event_id): auth = [] for auth_event_id in event.auth_event_ids(): - ae = persisted_events.get(auth_event_id) + ae = auth_map.get(auth_event_id) if not ae: # the fact we can't find the auth event doesn't mean it doesn't # exist, which means it is premature to reject `event`. Instead we @@ -1755,7 +1737,9 @@ class FederationEventHandler: context = EventContext.for_outlier(self._storage_controllers) try: validate_event_for_room_version(event) - await check_state_independent_auth_rules(self._store, event) + await check_state_independent_auth_rules( + self._store, event, batched_auth_events=auth_map + ) check_state_dependent_auth_rules(event, auth) except AuthError as e: logger.warning("Rejecting %r because %s", event, e) @@ -1772,7 +1756,7 @@ class FederationEventHandler: events_and_contexts_to_persist.append((event, context)) - for event in fetched_events: + for event in sorted_auth_events: await prep(event) await self.persist_events_and_notify( diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index a534f5f280..78bcac1429 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py
@@ -44,6 +44,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection +from synapse.types.state import StateFilter from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -546,7 +547,16 @@ class RoomSummaryHandler: Returns: True if the room is accessible to the requesting user or server. """ - state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) + event_types = [ + (EventTypes.JoinRules, ""), + (EventTypes.RoomHistoryVisibility, ""), + ] + if requester: + event_types.append((EventTypes.Member, requester)) + + state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, state_filter=StateFilter.from_types(event_types) + ) # If there's no state for the room, it isn't known. if not state_ids: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0385c04bc2..2e10035772 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -583,10 +583,11 @@ class SyncHandler: # `recents`, so partial state is only a problem when a membership # event turns up in `recents` but has not made it into the current # state. - current_state_ids_map = ( - await self.store.get_partial_current_state_ids(room_id) + current_state_ids = ( + await self.store.check_if_events_in_current_state( + {e.event_id for e in recents if e.is_state()} + ) ) - current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( self._storage_controllers, @@ -667,10 +668,11 @@ class SyncHandler: # `loaded_recents`, so partial state is only a problem when a # membership event turns up in `loaded_recents` but has not made it # into the current state. - current_state_ids_map = ( - await self.store.get_partial_current_state_ids(room_id) + current_state_ids = ( + await self.store.check_if_events_in_current_state( + {e.event_id for e in loaded_recents if e.is_state()} + ) ) - current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( self._storage_controllers, diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 03ce0b4dc6..cce9583fa7 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py
@@ -28,17 +28,11 @@ from synapse.storage.databases.main import DataStore async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: invites = await store.get_invited_rooms_for_local_user(user_id) - joins = await store.get_rooms_for_user(user_id) badge = len(invites) room_to_count = await store.get_unread_counts_by_room_for_user(user_id) - for room_id, notify_count in room_to_count.items(): - # room_to_count may include rooms which the user has left, - # ignore those. - if room_id not in joins: - continue - + for _room_id, notify_count in room_to_count.items(): if notify_count == 0: continue diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4bbcf7199c..5a1a3e8e65 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -245,33 +245,74 @@ class DeviceInboxWorkerStore(SQLBaseStore): * The last-processed stream ID. Subsequent calls of this function with the same device should pass this value as 'from_stream_id'. """ - ( - user_id_device_id_to_messages, - last_processed_stream_id, - ) = await self._get_device_messages( - user_ids=[user_id], - device_id=device_id, - from_stream_id=from_stream_id, - to_stream_id=to_stream_id, - limit=limit, - ) - - if not user_id_device_id_to_messages: + if not self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): # There were no messages! return [], to_stream_id - # Extract the messages, no need to return the user and device ID again - to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), []) + def get_device_messages_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + sql = """ + SELECT stream_id, message_json FROM device_inbox + WHERE user_id = ? AND device_id = ? + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit)) + + # Create and fill a dictionary of (user ID, device ID) -> list of messages + # intended for each device. + last_processed_stream_pos = to_stream_id + to_device_messages: List[JsonDict] = [] + rowcount = 0 + for row in txn: + rowcount += 1 + + last_processed_stream_pos = row[0] + message_dict = db_to_json(row[1]) + + # Store the device details + to_device_messages.append(message_dict) - return to_device_messages, last_processed_stream_id + # start a new span for each message, so that we can tag each separately + with start_active_span("get_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + + if rowcount == limit: + # We ended up bumping up against the message limit. There may be more messages + # to retrieve. Return what we have, as well as the last stream position that + # was processed. + # + # The caller is expected to set this as the lower (exclusive) bound + # for the next query of this device. + return to_device_messages, last_processed_stream_pos + + # The limit was not reached, thus we know that recipient_device_to_messages + # contains all to-device messages for the given device and stream id range. + # + # We return to_stream_id, which the caller should then provide as the lower + # (exclusive) bound on the next query of this device. + return to_device_messages, to_stream_id + + return await self.db_pool.runInteraction( + "get_messages_for_device", get_device_messages_txn + ) async def _get_device_messages( self, user_ids: Collection[str], from_stream_id: int, to_stream_id: int, - device_id: Optional[str] = None, - limit: Optional[int] = None, ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: """ Retrieve pending to-device messages for a collection of user devices. @@ -291,11 +332,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): user_ids: The user IDs to filter device messages by. from_stream_id: The lower boundary of stream id to filter with (exclusive). to_stream_id: The upper boundary of stream id to filter with (inclusive). - device_id: A device ID to query to-device messages for. If not provided, to-device - messages from all device IDs for the given user IDs will be queried. May not be - provided if `user_ids` contains more than one entry. - limit: The maximum number of to-device messages to return. Can only be used when - passing a single user ID / device ID tuple. + Returns: A tuple containing: @@ -308,30 +345,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): logger.warning("No users provided upon querying for device IDs") return {}, to_stream_id - # Prevent a query for one user's device also retrieving another user's device with - # the same device ID (device IDs are not unique across users). - if len(user_ids) > 1 and device_id is not None: - raise AssertionError( - "Programming error: 'device_id' cannot be supplied to " - "_get_device_messages when >1 user_id has been provided" - ) - - # A limit can only be applied when querying for a single user ID / device ID tuple. - # See the docstring of this function for more details. - if limit is not None and device_id is None: - raise AssertionError( - "Programming error: _get_device_messages was passed 'limit' " - "without a specific user_id/device_id" - ) - user_ids_to_query: Set[str] = set() - device_ids_to_query: Set[str] = set() - - # Note that a device ID could be an empty str - if device_id is not None: - # If a device ID was passed, use it to filter results. - # Otherwise, device IDs will be derived from the given collection of user IDs. - device_ids_to_query.add(device_id) # Determine which users have devices with pending messages for user_id in user_ids: @@ -355,20 +369,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): # hidden devices should not receive to-device messages. # Note that this is more efficient than just dropping `device_id` from the query, # since device_inbox has an index on `(user_id, device_id, stream_id)` - if not device_ids_to_query: - user_device_dicts = cast( - List[Tuple[str]], - self.db_pool.simple_select_many_txn( - txn, - table="devices", - column="user_id", - iterable=user_ids_to_query, - keyvalues={"hidden": False}, - retcols=("device_id",), - ), - ) - device_ids_to_query.update({row[0] for row in user_device_dicts}) + user_device_dicts = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="devices", + column="user_id", + iterable=user_ids_to_query, + keyvalues={"hidden": False}, + retcols=("device_id",), + ), + ) + + device_ids_to_query = {row[0] for row in user_device_dicts} if not device_ids_to_query: # We've ended up with no devices to query. @@ -400,22 +414,15 @@ class DeviceInboxWorkerStore(SQLBaseStore): to_stream_id, ) - # If a limit was provided, limit the data retrieved from the database - if limit is not None: - sql += "LIMIT ?" - sql_args += (limit,) - txn.execute(sql, sql_args) # Create and fill a dictionary of (user ID, device ID) -> list of messages # intended for each device. - last_processed_stream_pos = to_stream_id recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} rowcount = 0 for row in txn: rowcount += 1 - last_processed_stream_pos = row[0] recipient_user_id = row[1] recipient_device_id = row[2] message_dict = db_to_json(row[3]) @@ -436,18 +443,6 @@ class DeviceInboxWorkerStore(SQLBaseStore): message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), ) - if limit is not None and rowcount == limit: - # We ended up bumping up against the message limit. There may be more messages - # to retrieve. Return what we have, as well as the last stream position that - # was processed. - # - # The caller is expected to set this as the lower (exclusive) bound - # for the next query of this device. - return recipient_device_to_messages, last_processed_stream_pos - - # The limit was not reached, thus we know that recipient_device_to_messages - # contains all to-device messages for the given device and stream id range. - # # We return to_stream_id, which the caller should then provide as the lower # (exclusive) bound on the next query of this device. return recipient_device_to_messages, to_stream_id diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 38029710db..d3859014b6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -1796,7 +1796,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_ids: The IDs of the devices to delete """ - def _delete_devices_txn(txn: LoggingTransaction) -> None: + def _delete_devices_txn(txn: LoggingTransaction, device_ids: List[str]) -> None: self.db_pool.simple_delete_many_txn( txn, table="devices", @@ -1813,7 +1813,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) - await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) + for batch in batch_iter(device_ids, 100): + await self.db_pool.runInteraction( + "delete_devices", _delete_devices_txn, batch + ) + for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 650b8c8135..6d4e2942ea 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -357,10 +357,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas This function is intentionally not cached because it is called to calculate the unread badge for push notifications and thus the result is expected to change. - Note that this function assumes the user is a member of the room. Because - summary rows are not removed when a user leaves a room, the caller must - filter out those results from the result. - Returns: A map of room ID to notification counts for the given user. """ @@ -373,127 +369,170 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_unread_counts_by_room_for_user_txn( self, txn: LoggingTransaction, user_id: str ) -> Dict[str, int]: - receipt_types_clause, args = make_in_list_sql_clause( + # To get the badge count of all rooms we need to make three queries: + # 1. Fetch all counts from `event_push_summary`, discarding any stale + # rooms. + # 2. Fetch all notifications from `event_push_actions` that haven't + # been rotated yet. + # 3. Fetch all notifications from `event_push_actions` for the stale + # rooms. + # + # The "stale room" scenario generally happens when there is a new read + # receipt that hasn't yet been processed to update the + # `event_push_summary` table. When that happens we ignore the + # `event_push_summary` table for that room and calculate the count + # manually from `event_push_actions`. + + # We need to only take into account read receipts of these types. + receipt_types_clause, receipt_types_args = make_in_list_sql_clause( self.database_engine, "receipt_type", (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) - args.extend([user_id, user_id]) - - receipts_cte = f""" - WITH all_receipts AS ( - SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering - FROM receipts_linearized - LEFT JOIN events USING (room_id, event_id) - WHERE - {receipt_types_clause} - AND user_id = ? - GROUP BY room_id, thread_id - ) - """ - - receipts_joins = """ - LEFT JOIN ( - SELECT room_id, thread_id, - max_receipt_stream_ordering AS threaded_receipt_stream_ordering - FROM all_receipts - WHERE thread_id IS NOT NULL - ) AS threaded_receipts USING (room_id, thread_id) - LEFT JOIN ( - SELECT room_id, thread_id, - max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering - FROM all_receipts - WHERE thread_id IS NULL - ) AS unthreaded_receipts USING (room_id) - """ - - # First get summary counts by room / thread for the user. We use the max receipt - # stream ordering of both threaded & unthreaded receipts to compare against the - # summary table. - # - # PostgreSQL and SQLite differ in comparing scalar numerics. - if isinstance(self.database_engine, PostgresEngine): - # GREATEST ignores NULLs. - max_clause = """GREATEST( - threaded_receipt_stream_ordering, - unthreaded_receipt_stream_ordering - )""" - else: - # MAX returns NULL if any are NULL, so COALESCE to 0 first. - max_clause = """MAX( - COALESCE(threaded_receipt_stream_ordering, 0), - COALESCE(unthreaded_receipt_stream_ordering, 0) - )""" + # Step 1, fetch all counts from `event_push_summary` for the user. This + # is slightly convoluted as we also need to pull out the stream ordering + # of the most recent receipt of the user in the room (either a thread + # aware receipt or thread unaware receipt) in order to determine + # whether the row in `event_push_summary` is stale. Hence the outer + # GROUP BY and odd join condition against `receipts_linearized`. sql = f""" - {receipts_cte} - SELECT eps.room_id, eps.thread_id, notif_count - FROM event_push_summary AS eps - {receipts_joins} - WHERE user_id = ? - AND notif_count != 0 - AND ( - (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause}) - OR last_receipt_stream_ordering = {max_clause} + SELECT room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering, + MAX(receipt_stream_ordering) + FROM ( + SELECT e.room_id, notif_count, e.stream_ordering, e.thread_id, last_receipt_stream_ordering, + ev.stream_ordering AS receipt_stream_ordering + FROM event_push_summary AS e + INNER JOIN local_current_membership USING (user_id, room_id) + LEFT JOIN receipts_linearized AS r ON ( + e.user_id = r.user_id + AND e.room_id = r.room_id + AND (e.thread_id = r.thread_id OR r.thread_id IS NULL) + AND {receipt_types_clause} ) + LEFT JOIN events AS ev ON (r.event_id = ev.event_id) + WHERE e.user_id = ? and notif_count > 0 + ) AS es + GROUP BY room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering """ - txn.execute(sql, args) - - seen_thread_ids = set() - room_to_count: Dict[str, int] = defaultdict(int) - for room_id, thread_id, notif_count in txn: - room_to_count[room_id] += notif_count - seen_thread_ids.add(thread_id) + txn.execute( + sql, + receipt_types_args + + [ + user_id, + ], + ) - # Now get any event push actions that haven't been rotated using the same OR - # join and filter by receipt and event push summary rotated up to stream ordering. - sql = f""" - {receipts_cte} - SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count - FROM event_push_actions AS epa - {receipts_joins} - WHERE user_id = ? - AND epa.notif = 1 - AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering) - AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) - AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) - GROUP BY epa.room_id, epa.thread_id - """ - txn.execute(sql, args) + room_to_count: Dict[str, int] = defaultdict(int) + stale_room_ids = set() + for row in txn: + room_id = row[0] + notif_count = row[1] + stream_ordering = row[2] + _thread_id = row[3] + last_receipt_stream_ordering = row[4] + receipt_stream_ordering = row[5] + + if last_receipt_stream_ordering is None: + if receipt_stream_ordering is None: + room_to_count[room_id] += notif_count + elif stream_ordering > receipt_stream_ordering: + room_to_count[room_id] += notif_count + else: + # The latest read receipt from the user is after all the rows for + # this room in `event_push_summary`. We ignore them, and + # calculate the count from `event_push_actions` in step 3. + pass + elif last_receipt_stream_ordering == receipt_stream_ordering: + room_to_count[room_id] += notif_count + else: + # The row is stale if `last_receipt_stream_ordering` is set and + # *doesn't* match the latest receipt from the user. + stale_room_ids.add(room_id) - for room_id, thread_id, notif_count in txn: - # Note: only count push actions we have valid summaries for with up to date receipt. - if thread_id not in seen_thread_ids: - continue - room_to_count[room_id] += notif_count + # Discard any stale rooms from `room_to_count`, as we will recalculate + # them in step 3. + for room_id in stale_room_ids: + room_to_count.pop(room_id, None) - thread_id_clause, thread_ids_args = make_in_list_sql_clause( - self.database_engine, "epa.thread_id", seen_thread_ids + # Step 2, basically the same query, except against `event_push_actions` + # and only fetching rows inserted since the last rotation. + rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", ) - # Finally re-check event_push_actions for any rooms not in the summary, ignoring - # the rotated up-to position. This handles the case where a read receipt has arrived - # but not been rotated meaning the summary table is out of date, so we go back to - # the push actions table. sql = f""" - {receipts_cte} - SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count - FROM event_push_actions AS epa - {receipts_joins} - WHERE user_id = ? - AND NOT {thread_id_clause} - AND epa.notif = 1 - AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) - AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) - GROUP BY epa.room_id + SELECT room_id, thread_id + FROM ( + SELECT e.room_id, e.stream_ordering, e.thread_id, + ev.stream_ordering AS receipt_stream_ordering + FROM event_push_actions AS e + INNER JOIN local_current_membership USING (user_id, room_id) + LEFT JOIN receipts_linearized AS r ON ( + e.user_id = r.user_id + AND e.room_id = r.room_id + AND (e.thread_id = r.thread_id OR r.thread_id IS NULL) + AND {receipt_types_clause} + ) + LEFT JOIN events AS ev ON (r.event_id = ev.event_id) + WHERE e.user_id = ? and notif > 0 + AND e.stream_ordering > ? + ) AS es + GROUP BY room_id, stream_ordering, thread_id + HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0) """ - args.extend(thread_ids_args) - txn.execute(sql, args) + txn.execute( + sql, + receipt_types_args + [user_id, rotated_upto_stream_ordering], + ) + for room_id, _thread_id in txn: + # Again, we ignore any stale rooms. + if room_id not in stale_room_ids: + # For event push actions it is one notification per row. + room_to_count[room_id] += 1 + + # Step 3, if we have stale rooms then we need to recalculate the counts + # from `event_push_actions`. Again, this is basically the same query as + # above except without a lower bound on stream ordering and only against + # a specific set of rooms. + if stale_room_ids: + room_id_clause, room_id_args = make_in_list_sql_clause( + self.database_engine, + "e.room_id", + stale_room_ids, + ) - for room_id, notif_count in txn: - room_to_count[room_id] += notif_count + sql = f""" + SELECT room_id, thread_id + FROM ( + SELECT e.room_id, e.stream_ordering, e.thread_id, + ev.stream_ordering AS receipt_stream_ordering + FROM event_push_actions AS e + INNER JOIN local_current_membership USING (user_id, room_id) + LEFT JOIN receipts_linearized AS r ON ( + e.user_id = r.user_id + AND e.room_id = r.room_id + AND (e.thread_id = r.thread_id OR r.thread_id IS NULL) + AND {receipt_types_clause} + ) + LEFT JOIN events AS ev ON (r.event_id = ev.event_id) + WHERE e.user_id = ? and notif > 0 + AND {room_id_clause} + ) AS es + GROUP BY room_id, stream_ordering, thread_id + HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0) + """ + txn.execute( + sql, + receipt_types_args + [user_id] + room_id_args, + ) + for room_id, _ in txn: + room_to_count[room_id] += 1 return room_to_count diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 4700e74ad2..8006046453 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -24,13 +24,17 @@ from typing import ( Any, Collection, Dict, + FrozenSet, Iterable, List, Mapping, Optional, Set, Tuple, + TypeVar, + Union, cast, + overload, ) import attr @@ -52,7 +56,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types import JsonDict, JsonMapping, StateKey, StateMap, StrCollection from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -64,6 +68,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_T = TypeVar("_T") + MAX_STATE_DELTA_HOPS = 100 @@ -318,6 +324,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_partial_current_state_ids", _get_current_state_ids_txn ) + async def check_if_events_in_current_state( + self, event_ids: StrCollection + ) -> FrozenSet[str]: + """Checks and returns which of the given events is part of the current state.""" + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="event_id", + iterable=event_ids, + retcols=("event_id",), + desc="check_if_events_in_current_state", + ) + + return frozenset(event_id for event_id, in rows) + # FIXME: how should this be cached? @cancellable async def get_partial_filtered_current_state_ids( @@ -349,7 +369,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = {} + results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + sql = """ SELECT type, state_key, event_id FROM current_state_events WHERE room_id = ? @@ -726,3 +747,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) + + +@attr.s(auto_attribs=True, slots=True) +class StateMapWrapper(Dict[StateKey, str]): + """A wrapper around a StateMap[str] to ensure that we only query for items + that were not filtered out. + + This is to help prevent bugs where we filter out state but other bits of the + code expect the state to be there. + """ + + state_filter: StateFilter + + def __getitem__(self, key: StateKey) -> str: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + return super().__getitem__(key) + + @overload + def get(self, key: Tuple[str, str]) -> Optional[str]: + ... + + @overload + def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]: + ... + + def get( + self, key: StateKey, default: Union[str, _T, None] = None + ) -> Union[str, _T, None]: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + return super().get(key, default) + + def __contains__(self, key: Any) -> bool: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + + return super().__contains__(key) diff --git a/synapse/types/state.py b/synapse/types/state.py
index 5ca3c94bce..53662372af 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py
@@ -20,6 +20,7 @@ import logging from typing import ( TYPE_CHECKING, + Any, Callable, Collection, Dict, @@ -584,6 +585,29 @@ class StateFilter: # local users only return False + def __contains__(self, key: Any) -> bool: + if not isinstance(key, tuple) or len(key) != 2: + raise TypeError( + f"'in StateFilter' requires (str, str) as left operand, not {type(key).__name__}" + ) + + typ, state_key = key + + if not isinstance(typ, str) or not isinstance(state_key, str): + raise TypeError( + f"'in StateFilter' requires (str, str) as left operand, not ({type(typ).__name__}, {type(state_key).__name__})" + ) + + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + return True + + elif self.include_others: + return True + + return False + _ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True) _ALL_NON_MEMBER_STATE_FILTER = StateFilter(