Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--   synapse/storage/databases/main/account_data.py      | 164
-rw-r--r--   synapse/storage/databases/main/appservice.py         |  24
-rw-r--r--   synapse/storage/databases/main/cache.py              |  16
-rw-r--r--   synapse/storage/databases/main/deviceinbox.py        | 279
-rw-r--r--   synapse/storage/databases/main/devices.py            |  22
-rw-r--r--   synapse/storage/databases/main/event_federation.py   | 325
-rw-r--r--   synapse/storage/databases/main/events.py              |  19
-rw-r--r--   synapse/storage/databases/main/push_rule.py           |  20
-rw-r--r--   synapse/storage/databases/main/relations.py           | 427
9 files changed, 916 insertions(+), 380 deletions(-)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 5bfa408f74..52146aacc8 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -106,6 +106,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) "AccountDataAndTagsChangeCache", account_max ) + self.db_pool.updates.register_background_update_handler( + "delete_account_data_for_deactivated_users", + self._delete_account_data_for_deactivated_users, + ) + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -549,72 +554,121 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def purge_account_data_for_user(self, user_id: str) -> None: """ - Removes the account data for a user. + Removes ALL the account data for a user. + Intended to be used upon user deactivation. - This is intended to be used upon user deactivation and also removes any - derived information from account data (e.g. push rules and ignored users). + Also purges the user from the ignored_users cache table + and the push_rules cache tables. + """ - Args: - user_id: The user ID to remove data for. + await self.db_pool.runInteraction( + "purge_account_data_for_user_txn", + self._purge_account_data_for_user_txn, + user_id, + ) + + def _purge_account_data_for_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> None: """ + See `purge_account_data_for_user`. + """ + # Purge from the primary account_data tables. + self.db_pool.simple_delete_txn( + txn, table="account_data", keyvalues={"user_id": user_id} + ) - def purge_account_data_for_user_txn(txn: LoggingTransaction) -> None: - # Purge from the primary account_data tables. - self.db_pool.simple_delete_txn( - txn, table="account_data", keyvalues={"user_id": user_id} - ) + self.db_pool.simple_delete_txn( + txn, table="room_account_data", keyvalues={"user_id": user_id} + ) - self.db_pool.simple_delete_txn( - txn, table="room_account_data", keyvalues={"user_id": user_id} - ) + # Purge from ignored_users where this user is the ignorer. + # N.B. We don't purge where this user is the ignoree, because that + # interferes with other users' account data. + # It's also not this user's data to delete! + self.db_pool.simple_delete_txn( + txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id} + ) - # Purge from ignored_users where this user is the ignorer. - # N.B. We don't purge where this user is the ignoree, because that - # interferes with other users' account data. - # It's also not this user's data to delete! 
- self.db_pool.simple_delete_txn( - txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id} - ) + # Remove the push rules + self.db_pool.simple_delete_txn( + txn, table="push_rules", keyvalues={"user_name": user_id} + ) + self.db_pool.simple_delete_txn( + txn, table="push_rules_enable", keyvalues={"user_name": user_id} + ) + self.db_pool.simple_delete_txn( + txn, table="push_rules_stream", keyvalues={"user_id": user_id} + ) - # Remove the push rules - self.db_pool.simple_delete_txn( - txn, table="push_rules", keyvalues={"user_name": user_id} - ) - self.db_pool.simple_delete_txn( - txn, table="push_rules_enable", keyvalues={"user_name": user_id} - ) - self.db_pool.simple_delete_txn( - txn, table="push_rules_stream", keyvalues={"user_id": user_id} - ) + # Invalidate caches as appropriate + self._invalidate_cache_and_stream( + txn, self.get_account_data_for_room_and_type, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_account_data_for_user, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_global_account_data_by_type_for_user, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_account_data_for_room, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,)) + self._invalidate_cache_and_stream( + txn, self.get_push_rules_enabled_for_user, (user_id,) + ) + # This user might be contained in the ignored_by cache for other users, + # so we have to invalidate it all. + self._invalidate_all_cache_and_stream(txn, self.ignored_by) - # Invalidate caches as appropriate - self._invalidate_cache_and_stream( - txn, self.get_account_data_for_room_and_type, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_account_data_for_user, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_global_account_data_by_type_for_user, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_account_data_for_room, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_push_rules_for_user, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_push_rules_enabled_for_user, (user_id,) - ) - # This user might be contained in the ignored_by cache for other users, - # so we have to invalidate it all. - self._invalidate_all_cache_and_stream(txn, self.ignored_by) + async def _delete_account_data_for_deactivated_users( + self, progress: dict, batch_size: int + ) -> int: + """ + Retroactively purges account data for users that have already been deactivated. + Gets run as a background update caused by a schema delta. + """ - await self.db_pool.runInteraction( - "purge_account_data_for_user_txn", - purge_account_data_for_user_txn, + last_user: str = progress.get("last_user", "") + + def _delete_account_data_for_deactivated_users_txn( + txn: LoggingTransaction, + ) -> int: + sql = """ + SELECT name FROM users + WHERE deactivated = ? and name > ? + ORDER BY name ASC + LIMIT ? 
+ """ + + txn.execute(sql, (1, last_user, batch_size)) + users = [row[0] for row in txn] + + for user in users: + self._purge_account_data_for_user_txn(txn, user_id=user) + + if users: + self.db_pool.updates._background_update_progress_txn( + txn, + "delete_account_data_for_deactivated_users", + {"last_user": users[-1]}, + ) + + return len(users) + + number_deleted = await self.db_pool.runInteraction( + "_delete_account_data_for_deactivated_users", + _delete_account_data_for_deactivated_users_txn, ) + if number_deleted < batch_size: + await self.db_pool.updates._end_background_update( + "delete_account_data_for_deactivated_users" + ) + + return number_deleted + class AccountDataStore(AccountDataWorkerStore): pass diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 2bb5288431..304814af5d 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore( service: ApplicationService, events: List[EventBase], ephemeral: List[JsonDict], + to_device_messages: List[JsonDict], ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore( service: The service who the transaction is for. events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. + to_device_messages: A list of to-device messages to put in the transaction. Returns: A new transaction. @@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore( (service.id, new_txn_id, event_ids), ) return AppServiceTransaction( - service=service, id=new_txn_id, events=events, ephemeral=ephemeral + service=service, + id=new_txn_id, + events=events, + ephemeral=ephemeral, + to_device_messages=to_device_messages, ) return await self.db_pool.runInteraction( @@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) return AppServiceTransaction( - service=service, id=entry["txn_id"], events=events, ephemeral=[] + service=service, + id=entry["txn_id"], + events=events, + ephemeral=[], + to_device_messages=[], ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence"): + if type not in ("read_receipt", "presence", "to_device"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore( "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn ) - async def set_type_stream_id_for_appservice( + async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence"): + if stream_type not in ("read_receipt", "presence", "to_device"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) ) - def set_type_stream_id_for_appservice_txn(txn): + def set_appservice_stream_type_pos_txn(txn): stream_id_type = "%s_stream_id" % stream_type txn.execute( "UPDATE application_services_state SET %s = ? WHERE as_id=?" 
@@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore( ) await self.db_pool.runInteraction( - "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 0024348067..c428dd5596 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -15,7 +15,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes from synapse.replication.tcp.streams import BackfillStream, CachesStream @@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, ) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -236,7 +240,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate_all) self._send_invalidation_to_replication(txn, cache_func.__name__, None) - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + def _invalidate_state_caches_and_stream( + self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str] + ) -> None: """Special case invalidation of caches based on current state. We special case this so that we can batch the cache invalidations into a @@ -244,8 +250,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): Args: txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed + room_id: Room where state changed + members_changed: The user_ids of members that have changed """ txn.call_after(self._invalidate_state_caches, room_id, members_changed) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 4eca97189b..1392363de1 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -24,6 +24,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -136,63 +137,263 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() - async def get_new_messages_for_device( + async def get_messages_for_user_devices( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + ) -> Dict[Tuple[str, str], List[JsonDict]]: + """ + Retrieve to-device messages for a given set of users. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. 
+ + Args: + user_ids: The users to retrieve to-device messages for. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + + Returns: + A dictionary of (user id, device id) -> list of to-device messages. + """ + # We expect the stream ID returned by _get_device_messages to always + # be to_stream_id. So, no need to return it from this function. + ( + user_id_device_id_to_messages, + last_processed_stream_id, + ) = await self._get_device_messages( + user_ids=user_ids, + from_stream_id=from_stream_id, + to_stream_id=to_stream_id, + ) + + assert ( + last_processed_stream_id == to_stream_id + ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`" + + return user_id_device_id_to_messages + + async def get_messages_for_device( self, user_id: str, - device_id: Optional[str], - last_stream_id: int, - current_stream_id: int, + device_id: str, + from_stream_id: int, + to_stream_id: int, limit: int = 100, - ) -> Tuple[List[dict], int]: + ) -> Tuple[List[JsonDict], int]: """ + Retrieve to-device messages for a single user device. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + Args: - user_id: The recipient user_id. - device_id: The recipient device_id. - last_stream_id: The last stream ID checked. - current_stream_id: The current position of the to device - message stream. - limit: The maximum number of messages to retrieve. + user_id: The ID of the user to retrieve messages for. + device_id: The ID of the device to retrieve to-device messages for. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + limit: A limit on the number of to-device messages returned. Returns: A tuple containing: - * A list of messages for the device. - * The max stream token of these messages. There may be more to retrieve - if the given limit was reached. + * A list of to-device messages within the given stream id range intended for + the given user / device combo. + * The last-processed stream ID. Subsequent calls of this function with the + same device should pass this value as 'from_stream_id'. """ - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_stream_id + ( + user_id_device_id_to_messages, + last_processed_stream_id, + ) = await self._get_device_messages( + user_ids=[user_id], + device_id=device_id, + from_stream_id=from_stream_id, + to_stream_id=to_stream_id, + limit=limit, ) - if not has_changed: - return [], current_stream_id - def get_new_messages_for_device_txn(txn): - sql = ( - "SELECT stream_id, message_json FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" + if not user_id_device_id_to_messages: + # There were no messages! + return [], to_stream_id + + # Extract the messages, no need to return the user and device ID again + to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), []) + + return to_device_messages, last_processed_stream_id + + async def _get_device_messages( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + device_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: + """ + Retrieve pending to-device messages for a collection of user devices. 
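[Aside, not part of the diff] Given the contract documented above for get_messages_for_device — results cover the half-open range (from_stream_id, to_stream_id] and the second return value is the last-processed stream ID — a caller could drain a single device's inbox roughly as below. The loop and the FakeStore are illustrative only, not code from this change:

import asyncio
from typing import Any, Dict, List, Tuple


async def drain_device_inbox(
    store: Any, user_id: str, device_id: str, from_id: int, to_id: int
) -> List[Dict[str, Any]]:
    """Collect every to-device message in (from_id, to_id] for one device."""
    collected: List[Dict[str, Any]] = []
    while from_id < to_id:
        messages, last_processed = await store.get_messages_for_device(
            user_id, device_id, from_id, to_id, limit=100
        )
        collected.extend(messages)
        # The last-processed position becomes the next exclusive lower bound;
        # once it reaches to_id there is nothing further to fetch.
        from_id = last_processed
    return collected


class FakeStore:
    """In-memory stand-in keyed by stream ID, for demonstration only."""

    def __init__(self, inbox: Dict[int, Dict[str, Any]]) -> None:
        self._inbox = inbox

    async def get_messages_for_device(
        self, user_id: str, device_id: str, from_id: int, to_id: int, limit: int = 100
    ) -> Tuple[List[Dict[str, Any]], int]:
        stream_ids = sorted(s for s in self._inbox if from_id < s <= to_id)[:limit]
        if len(stream_ids) < limit:
            return [self._inbox[s] for s in stream_ids], to_id
        return [self._inbox[s] for s in stream_ids], stream_ids[-1]


store = FakeStore({1: {"type": "m.room_key_request"}, 2: {"type": "m.dummy"}})
messages = asyncio.run(drain_device_inbox(store, "@alice:test", "DEV", 0, 2))
assert len(messages) == 2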
+ + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + + Note that a stream ID can be shared by multiple copies of the same message with + different recipient devices. Stream IDs are only unique in the context of a single + user ID / device ID pair. Thus, applying a limit (of messages to return) when working + with a sliding window of stream IDs is only possible when querying messages of a + single user device. + + Finally, note that device IDs are not unique across users. + + Args: + user_ids: The user IDs to filter device messages by. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + device_id: A device ID to query to-device messages for. If not provided, to-device + messages from all device IDs for the given user IDs will be queried. May not be + provided if `user_ids` contains more than one entry. + limit: The maximum number of to-device messages to return. Can only be used when + passing a single user ID / device ID tuple. + + Returns: + A tuple containing: + * A dict of (user_id, device_id) -> list of to-device messages + * The last-processed stream ID. If this is less than `to_stream_id`, then + there may be more messages to retrieve. If `limit` is not set, then this + is always equal to 'to_stream_id'. + """ + if not user_ids: + logger.warning("No users provided upon querying for device IDs") + return {}, to_stream_id + + # Prevent a query for one user's device also retrieving another user's device with + # the same device ID (device IDs are not unique across users). + if len(user_ids) > 1 and device_id is not None: + raise AssertionError( + "Programming error: 'device_id' cannot be supplied to " + "_get_device_messages when >1 user_id has been provided" ) - txn.execute( - sql, (user_id, device_id, last_stream_id, current_stream_id, limit) + + # A limit can only be applied when querying for a single user ID / device ID tuple. + # See the docstring of this function for more details. + if limit is not None and device_id is None: + raise AssertionError( + "Programming error: _get_device_messages was passed 'limit' " + "without a specific user_id/device_id" ) - messages = [] - stream_pos = current_stream_id + user_ids_to_query: Set[str] = set() + device_ids_to_query: Set[str] = set() + + # Note that a device ID could be an empty str + if device_id is not None: + # If a device ID was passed, use it to filter results. + # Otherwise, device IDs will be derived from the given collection of user IDs. + device_ids_to_query.add(device_id) + + # Determine which users have devices with pending messages + for user_id in user_ids: + if self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): + # This user has new messages sent to them. Query messages for them + user_ids_to_query.add(user_id) + + def get_device_messages_txn(txn: LoggingTransaction): + # Build a query to select messages from any of the given devices that + # are between the given stream id bounds. + + # If a list of device IDs was not provided, retrieve all devices IDs + # for the given users. We explicitly do not query hidden devices, as + # hidden devices should not receive to-device messages. 
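[Aside, not part of the diff] The pre-filter above uses _device_inbox_stream_cache.has_entity_changed(...) to skip users with nothing new. Synapse's real StreamChangeCache is more involved (it only tracks changes back to an earliest-known position), so the class below is just a toy model of the "has this entity changed since stream position X?" idea:

from typing import Dict


class ToyStreamChangeCache:
    """Track the latest stream position at which each entity changed."""

    def __init__(self) -> None:
        self._latest_change: Dict[str, int] = {}

    def entity_has_changed(self, entity: str, stream_pos: int) -> None:
        current = self._latest_change.get(entity, 0)
        self._latest_change[entity] = max(current, stream_pos)

    def has_entity_changed(self, entity: str, since: int) -> bool:
        # True if the entity changed at a stream position strictly after `since`.
        return self._latest_change.get(entity, 0) > since


cache = ToyStreamChangeCache()
cache.entity_has_changed("@alice:test", 5)
assert cache.has_entity_changed("@alice:test", 4)
assert not cache.has_entity_changed("@alice:test", 5)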
+ # Note that this is more efficient than just dropping `device_id` from the query, + # since device_inbox has an index on `(user_id, device_id, stream_id)` + if not device_ids_to_query: + user_device_dicts = self.db_pool.simple_select_many_txn( + txn, + table="devices", + column="user_id", + iterable=user_ids_to_query, + keyvalues={"user_id": user_id, "hidden": False}, + retcols=("device_id",), + ) - for row in txn: - stream_pos = row[0] - messages.append(db_to_json(row[1])) + device_ids_to_query.update( + {row["device_id"] for row in user_device_dicts} + ) - # If the limit was not reached we know that there's no more data for this - # user/device pair up to current_stream_id. - if len(messages) < limit: - stream_pos = current_stream_id + if not device_ids_to_query: + # We've ended up with no devices to query. + return {}, to_stream_id - return messages, stream_pos + # We include both user IDs and device IDs in this query, as we have an index + # (device_inbox_user_stream_id) for them. + user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids_to_query + ) + ( + device_id_many_clause_sql, + device_id_many_clause_args, + ) = make_in_list_sql_clause( + self.database_engine, "device_id", device_ids_to_query + ) + + sql = f""" + SELECT stream_id, user_id, device_id, message_json FROM device_inbox + WHERE {user_id_many_clause_sql} + AND {device_id_many_clause_sql} + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + sql_args = ( + *user_id_many_clause_args, + *device_id_many_clause_args, + from_stream_id, + to_stream_id, + ) + + # If a limit was provided, limit the data retrieved from the database + if limit is not None: + sql += "LIMIT ?" + sql_args += (limit,) + + txn.execute(sql, sql_args) + + # Create and fill a dictionary of (user ID, device ID) -> list of messages + # intended for each device. + last_processed_stream_pos = to_stream_id + recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} + rowcount = 0 + for row in txn: + rowcount += 1 + + last_processed_stream_pos = row[0] + recipient_user_id = row[1] + recipient_device_id = row[2] + message_dict = db_to_json(row[3]) + + # Store the device details + recipient_device_to_messages.setdefault( + (recipient_user_id, recipient_device_id), [] + ).append(message_dict) + + if limit is not None and rowcount == limit: + # We ended up bumping up against the message limit. There may be more messages + # to retrieve. Return what we have, as well as the last stream position that + # was processed. + # + # The caller is expected to set this as the lower (exclusive) bound + # for the next query of this device. + return recipient_device_to_messages, last_processed_stream_pos + + # The limit was not reached, thus we know that recipient_device_to_messages + # contains all to-device messages for the given device and stream id range. + # + # We return to_stream_id, which the caller should then provide as the lower + # (exclusive) bound on the next query of this device. 
+ return recipient_device_to_messages, to_stream_id return await self.db_pool.runInteraction( - "get_new_messages_for_device", get_new_messages_for_device_txn + "get_device_messages", get_device_messages_txn ) @trace diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b2a5cd9a65..8d845fe951 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1496,13 +1496,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def add_device_change_to_streams( - self, user_id: str, device_ids: Collection[str], hosts: List[str] - ) -> int: + self, user_id: str, device_ids: Collection[str], hosts: Collection[str] + ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. + + Args: + user_id: The ID of the user whose device changed. + device_ids: The IDs of any changed devices. If empty, this function will + return None. + hosts: The remote destinations that should be notified of the change. + + Returns: + The maximum stream ID of device list updates that were added to the database, or + None if no updates were added. """ if not device_ids: - return + return None async with self._device_list_id_gen.get_next_mult( len(device_ids) @@ -1573,11 +1583,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Collection[str], - hosts: List[str], + device_ids: Iterable[str], + hosts: Collection[str], stream_ids: List[str], context: Dict[str, str], - ): + ) -> None: for host in hosts: txn.call_after( self._device_list_federation_stream_cache.entity_has_changed, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca71f073fc..277e6422eb 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -16,9 +16,10 @@ import logging from queue import Empty, PriorityQueue from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +import attr from prometheus_client import Counter, Gauge -from synapse.api.constants import MAX_DEPTH +from synapse.api.constants import MAX_DEPTH, EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict @@ -60,6 +61,15 @@ pdus_pruned_from_federation_queue = Counter( logger = logging.getLogger(__name__) +# All the info we need while iterating the DAG while backfilling +@attr.s(frozen=True, slots=True, auto_attribs=True) +class BackfillQueueNavigationItem: + depth: int + stream_ordering: int + event_id: str + type: str + + class _NoChainCoverIndex(Exception): def __init__(self, room_id: str): super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) @@ -74,6 +84,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ): super().__init__(database, db_conn, hs) + self.hs = hs + if hs.config.worker.run_background_tasks: hs.get_clock().looping_call( self._delete_old_forward_extrem_cache, 60 * 60 * 1000 @@ -109,7 +121,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id: str, event_ids: Collection[str], include_given: bool = False, - ) -> List[str]: + ) -> Set[str]: """Get auth events for given event_ids. The events *must* be state events. 
Args: @@ -118,7 +130,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas include_given: include the given events in result Returns: - list of event_ids + set of event_ids """ # Check if we have indexed the room so we can use the chain cover @@ -147,7 +159,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def _get_auth_chain_ids_using_cover_index_txn( self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool - ) -> List[str]: + ) -> Set[str]: """Calculates the auth chain IDs using the chain index.""" # First we look up the chain ID/sequence numbers for the given events. @@ -260,11 +272,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (chain_id, max_no)) results.update(r for r, in txn) - return list(results) + return results def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool - ) -> List[str]: + ) -> Set[str]: """Calculates the auth chain IDs. This is used when we don't have a cover index for the room. @@ -319,7 +331,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas front = new_front results.update(front) - return list(results) + return results async def get_auth_chain_difference( self, room_id: str, state_sets: List[Set[str]] @@ -737,7 +749,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) - async def get_insertion_event_backwards_extremities_in_room( + async def get_insertion_event_backward_extremities_in_room( self, room_id ) -> Dict[str, int]: """Get the insertion events we know about that we haven't backfilled yet. @@ -754,7 +766,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas Map from event_id to depth """ - def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id): + def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): sql = """ SELECT b.event_id, MAX(e.depth) FROM insertion_events as i /* We only want insertion events that are also marked as backwards extremities */ @@ -770,8 +782,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return dict(txn) return await self.db_pool.runInteraction( - "get_insertion_event_backwards_extremities_in_room", - get_insertion_event_backwards_extremities_in_room_txn, + "get_insertion_event_backward_extremities_in_room", + get_insertion_event_backward_extremities_in_room_txn, room_id, ) @@ -997,143 +1009,242 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - async def get_backfill_events(self, room_id: str, event_list: list, limit: int): - """Get a list of Events for a given topic that occurred before (and - including) the events in event_list. Return a list of max size `limit` + def _get_connected_batch_event_backfill_results_txn( + self, txn: LoggingTransaction, insertion_event_id: str, limit: int + ) -> List[BackfillQueueNavigationItem]: + """ + Find any batch connections of a given insertion event. + A batch event points at a insertion event via: + batch_event.content[MSC2716_BATCH_ID] -> insertion_event.content[MSC2716_NEXT_BATCH_ID] Args: - room_id - event_list - limit + txn: The database transaction to use + insertion_event_id: The event ID to navigate from. We will find + batch events that point back at this insertion event. 
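[Aside, not part of the diff] Context for the List[str] -> Set[str] change to get_auth_chain_ids above: without a chain-cover index, the lookup is a frontier walk over the auth graph, accumulating event IDs in a set. A self-contained sketch of that walk over an in-memory graph (the database-backed version above batches its queries; the graph and event IDs here are made up):

from typing import Dict, List, Set


def auth_chain_ids(
    auth_graph: Dict[str, List[str]], event_ids: List[str], include_given: bool = False
) -> Set[str]:
    """Walk auth_events edges breadth-first, collecting every reachable event ID."""
    results: Set[str] = set(event_ids) if include_given else set()
    front: Set[str] = set(event_ids)
    while front:
        new_front: Set[str] = set()
        for event_id in front:
            for auth_id in auth_graph.get(event_id, []):
                if auth_id not in results:
                    new_front.add(auth_id)
        front = new_front
        results.update(front)
    return results


graph = {"$m": ["$pl", "$join"], "$pl": ["$create"], "$join": ["$create"], "$create": []}
assert auth_chain_ids(graph, ["$m"]) == {"$pl", "$join", "$create"}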
+ limit: Max number of event ID's to query for and return + + Returns: + List of batch events that the backfill queue can process + """ + batch_connection_query = """ + SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i + /* Find the batch that connects to the given insertion event */ + INNER JOIN batch_events AS c + ON i.next_batch_id = c.batch_id + /* Get the depth of the batch start event from the events table */ + INNER JOIN events AS e USING (event_id) + /* Find an insertion event which matches the given event_id */ + WHERE i.event_id = ? + LIMIT ? """ - event_ids = await self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - events = await self.get_events_as_list(event_ids) - return sorted(events, key=lambda e: -e.depth) - def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) + # Find any batch connections for the given insertion event + txn.execute( + batch_connection_query, + (insertion_event_id, limit), + ) + return [ + BackfillQueueNavigationItem( + depth=row[0], + stream_ordering=row[1], + event_id=row[2], + type=row[3], + ) + for row in txn + ] - event_results = set() + def _get_connected_prev_event_backfill_results_txn( + self, txn: LoggingTransaction, event_id: str, limit: int + ) -> List[BackfillQueueNavigationItem]: + """ + Find any events connected by prev_event the specified event_id. - # We want to make sure that we do a breadth-first, "depth" ordered - # search. + Args: + txn: The database transaction to use + event_id: The event ID to navigate from + limit: Max number of event ID's to query for and return + Returns: + List of prev events that the backfill queue can process + """ # Look for the prev_event_id connected to the given event_id - query = """ - SELECT depth, prev_event_id FROM event_edges - /* Get the depth of the prev_event_id from the events table */ + connected_prev_event_query = """ + SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges + /* Get the depth and stream_ordering of the prev_event_id from the events table */ INNER JOIN events ON prev_event_id = events.event_id - /* Find an event which matches the given event_id */ + /* Look for an edge which matches the given event_id */ WHERE event_edges.event_id = ? AND event_edges.is_state = ? + /* Because we can have many events at the same depth, + * we want to also tie-break and sort on stream_ordering */ + ORDER BY depth DESC, stream_ordering DESC LIMIT ? """ - # Look for the "insertion" events connected to the given event_id - connected_insertion_event_query = """ - SELECT e.depth, i.event_id FROM insertion_event_edges AS i - /* Get the depth of the insertion event from the events table */ - INNER JOIN events AS e USING (event_id) - /* Find an insertion event which points via prev_events to the given event_id */ - WHERE i.insertion_prev_event_id = ? - LIMIT ? + txn.execute( + connected_prev_event_query, + (event_id, False, limit), + ) + return [ + BackfillQueueNavigationItem( + depth=row[0], + stream_ordering=row[1], + event_id=row[2], + type=row[3], + ) + for row in txn + ] + + async def get_backfill_events( + self, room_id: str, seed_event_id_list: list, limit: int + ): + """Get a list of Events for a given topic that occurred before (and + including) the events in seed_event_id_list. 
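[Aside, not part of the diff] The backfill changes in these hunks order events by depth and tie-break on stream_ordering, newest first. Because PriorityQueue pops the smallest tuple first, the code pushes negated keys; the standalone snippet below demonstrates that ordering with made-up events:

from queue import PriorityQueue
from typing import Tuple

queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()

# (depth, stream_ordering, event_id, type) for a handful of fake events.
events = [
    (3, 30, "$newest", "m.room.message"),
    (2, 21, "$mid_b", "m.room.message"),
    (2, 20, "$mid_a", "m.room.message"),
    (1, 10, "$oldest", "m.room.message"),
]

for depth, stream_ordering, event_id, event_type in events:
    # Negate both keys so the "largest" (newest) event is popped first.
    queue.put((-depth, -stream_ordering, event_id, event_type))

popped = []
while not queue.empty():
    _, _, event_id, _ = queue.get_nowait()
    popped.append(event_id)

assert popped == ["$newest", "$mid_b", "$mid_a", "$oldest"]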
Return a list of max size `limit` + + Args: + room_id + seed_event_id_list + limit """ + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + seed_event_id_list, + limit, + ) + events = await self.get_events_as_list(event_ids) + return sorted( + events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) + ) - # Find any batch connections of a given insertion event - batch_connection_query = """ - SELECT e.depth, c.event_id FROM insertion_events AS i - /* Find the batch that connects to the given insertion event */ - INNER JOIN batch_events AS c - ON i.next_batch_id = c.batch_id - /* Get the depth of the batch start event from the events table */ - INNER JOIN events AS e USING (event_id) - /* Find an insertion event which matches the given event_id */ - WHERE i.event_id = ? - LIMIT ? + def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): + """ + We want to make sure that we do a breadth-first, "depth" ordered search. + We also handle navigating historical branches of history connected by + insertion and batch events. """ + logger.debug( + "_get_backfill_events(room_id=%s): seeding backfill with seed_event_id_list=%s limit=%s", + room_id, + seed_event_id_list, + limit, + ) + + event_id_results = set() # In a PriorityQueue, the lowest valued entries are retrieved first. - # We're using depth as the priority in the queue. - # Depth is lowest at the oldest-in-time message and highest and - # newest-in-time message. We add events to the queue with a negative depth so that - # we process the newest-in-time messages first going backwards in time. + # We're using depth as the priority in the queue and tie-break based on + # stream_ordering. Depth is lowest at the oldest-in-time message and + # highest and newest-in-time message. We add events to the queue with a + # negative depth so that we process the newest-in-time messages first + # going backwards in time. stream_ordering follows the same pattern. queue = PriorityQueue() - for event_id in event_list: - depth = self.db_pool.simple_select_one_onecol_txn( + for seed_event_id in seed_event_id_list: + event_lookup_result = self.db_pool.simple_select_one_txn( txn, table="events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcol="depth", + keyvalues={"event_id": seed_event_id, "room_id": room_id}, + retcols=( + "type", + "depth", + "stream_ordering", + ), allow_none=True, ) - if depth: - queue.put((-depth, event_id)) + if event_lookup_result is not None: + logger.debug( + "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", + room_id, + seed_event_id, + event_lookup_result["depth"], + event_lookup_result["stream_ordering"], + event_lookup_result["type"], + ) + + if event_lookup_result["depth"]: + queue.put( + ( + -event_lookup_result["depth"], + -event_lookup_result["stream_ordering"], + seed_event_id, + event_lookup_result["type"], + ) + ) - while not queue.empty() and len(event_results) < limit: + while not queue.empty() and len(event_id_results) < limit: try: - _, event_id = queue.get_nowait() + _, _, event_id, event_type = queue.get_nowait() except Empty: break - if event_id in event_results: + if event_id in event_id_results: continue - event_results.add(event_id) + event_id_results.add(event_id) # Try and find any potential historical batches of message history. - # - # First we look for an insertion event connected to the current - # event (by prev_event). 
If we find any, we need to go and try to - # find any batch events connected to the insertion event (by - # batch_id). If we find any, we'll add them to the queue and - # navigate up the DAG like normal in the next iteration of the loop. - txn.execute( - connected_insertion_event_query, (event_id, limit - len(event_results)) - ) - connected_insertion_event_id_results = txn.fetchall() - logger.debug( - "_get_backfill_events: connected_insertion_event_query %s", - connected_insertion_event_id_results, - ) - for row in connected_insertion_event_id_results: - connected_insertion_event_depth = row[0] - connected_insertion_event = row[1] - queue.put((-connected_insertion_event_depth, connected_insertion_event)) + if self.hs.config.experimental.msc2716_enabled: + # We need to go and try to find any batch events connected + # to a given insertion event (by batch_id). If we find any, we'll + # add them to the queue and navigate up the DAG like normal in the + # next iteration of the loop. + if event_type == EventTypes.MSC2716_INSERTION: + # Find any batch connections for the given insertion event + connected_batch_event_backfill_results = ( + self._get_connected_batch_event_backfill_results_txn( + txn, event_id, limit - len(event_id_results) + ) + ) + logger.debug( + "_get_backfill_events(room_id=%s): connected_batch_event_backfill_results=%s", + room_id, + connected_batch_event_backfill_results, + ) + for ( + connected_batch_event_backfill_item + ) in connected_batch_event_backfill_results: + if ( + connected_batch_event_backfill_item.event_id + not in event_id_results + ): + queue.put( + ( + -connected_batch_event_backfill_item.depth, + -connected_batch_event_backfill_item.stream_ordering, + connected_batch_event_backfill_item.event_id, + connected_batch_event_backfill_item.type, + ) + ) - # Find any batch connections for the given insertion event - txn.execute( - batch_connection_query, - (connected_insertion_event, limit - len(event_results)), - ) - batch_start_event_id_results = txn.fetchall() - logger.debug( - "_get_backfill_events: batch_start_event_id_results %s", - batch_start_event_id_results, + # Now we just look up the DAG by prev_events as normal + connected_prev_event_backfill_results = ( + self._get_connected_prev_event_backfill_results_txn( + txn, event_id, limit - len(event_id_results) ) - for row in batch_start_event_id_results: - if row[1] not in event_results: - queue.put((-row[0], row[1])) - - txn.execute(query, (event_id, False, limit - len(event_results))) - prev_event_id_results = txn.fetchall() + ) logger.debug( - "_get_backfill_events: prev_event_ids %s", prev_event_id_results + "_get_backfill_events(room_id=%s): connected_prev_event_backfill_results=%s", + room_id, + connected_prev_event_backfill_results, ) + for ( + connected_prev_event_backfill_item + ) in connected_prev_event_backfill_results: + if connected_prev_event_backfill_item.event_id not in event_id_results: + queue.put( + ( + -connected_prev_event_backfill_item.depth, + -connected_prev_event_backfill_item.stream_ordering, + connected_prev_event_backfill_item.event_id, + connected_prev_event_backfill_item.type, + ) + ) - for row in prev_event_id_results: - if row[1] not in event_results: - queue.put((-row[0], row[1])) - - return event_results + return event_id_results async def get_missing_events(self, room_id, earliest_events, latest_events, limit): ids = await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 
b7554154ac..5246fccad5 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1801,9 +1801,7 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after( - self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) - ) + txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) if rel_type == RelationTypes.THREAD: txn.call_after( @@ -1814,7 +1812,7 @@ class PersistEventsStore: # potentially error-prone) so it is always invalidated. txn.call_after( self.store.get_thread_participated.invalidate, - (parent_id, event.room_id, event.sender), + (parent_id, event.sender), ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): @@ -2215,9 +2213,14 @@ class PersistEventsStore: " SELECT 1 FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" " )" + # 1. Don't add an event as a extremity again if we already persisted it + # as a non-outlier. + # 2. Don't add an outlier as an extremity if it has no prev_events " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" + " SELECT 1 FROM events" + " LEFT JOIN event_edges edge" + " ON edge.event_id = events.event_id" + " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)" " )" ) @@ -2243,6 +2246,10 @@ class PersistEventsStore: (ev.event_id, ev.room_id) for ev in events if not ev.internal_metadata.is_outlier() + # If we encountered an event with no prev_events, then we might + # as well remove it now because it won't ever have anything else + # to backfill from. + or len(ev.prev_event_ids()) == 0 ], ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e01c94930a..92539f5d41 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _load_rules(rawrules, enabled_map, use_new_defaults=False): +def _load_rules(rawrules, enabled_map): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -52,7 +52,7 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist, use_new_defaults)) + rules = list(list_with_base_rules(ruleslist)) for i, rule in enumerate(rules): rule_id = rule["rule_id"] @@ -112,10 +112,6 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - self._users_new_default_push_rules = ( - hs.config.server.users_new_default_push_rules - ) - @abc.abstractmethod def get_max_push_rules_stream_id(self): """Get the position of the push rules stream. 
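[Aside, not part of the diff] The reworked event_backward_extremities query in the events.py hunk above encodes two skip conditions, restated here as a plain predicate (illustrative only; the real check happens inside the SQL's EXISTS clauses):

from typing import Sequence


def skip_adding_backward_extremity(
    persisted: bool, is_outlier: bool, prev_event_ids: Sequence[str]
) -> bool:
    # 1. Don't re-add an event we have already persisted as a non-outlier.
    if persisted and not is_outlier:
        return True
    # 2. Don't add an outlier with no prev_events: nothing behind it to backfill from.
    if is_outlier and len(prev_event_ids) == 0:
        return True
    return False


assert skip_adding_backward_extremity(persisted=True, is_outlier=False, prev_event_ids=["$x"])
assert not skip_adding_backward_extremity(persisted=True, is_outlier=True, prev_event_ids=["$x"])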
@@ -145,9 +141,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - use_new_defaults = user_id in self._users_new_default_push_rules - - return _load_rules(rows, enabled_map, use_new_defaults) + return _load_rules(rows, enabled_map) @cached(max_entries=5000) async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]: @@ -206,13 +200,7 @@ class PushRulesWorkerStore( enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): - use_new_defaults = user_id in self._users_new_default_push_rules - - results[user_id] = _load_rules( - rules, - enabled_map_by_user.get(user_id, {}), - use_new_defaults, - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 37468a5183..e2c27e594b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,12 +13,23 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, + cast, +) import attr from frozendict import frozendict -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -28,16 +39,14 @@ from synapse.storage.database import ( make_in_list_sql_clause, ) from synapse.storage.databases.main.stream import generate_pagination_where_clause -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.types import JsonDict -from synapse.util.caches.descriptors import cached +from synapse.storage.engines import PostgresEngine +from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -87,8 +96,8 @@ class RelationsWorkerStore(SQLBaseStore): aggregation_key: Optional[str] = None, limit: int = 5, direction: str = "b", - from_token: Optional[RelationPaginationToken] = None, - to_token: Optional[RelationPaginationToken] = None, + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. 
@@ -127,8 +136,10 @@ class RelationsWorkerStore(SQLBaseStore): pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] + from_token=from_token.room_key.as_historical_tuple() + if from_token + else None, + to_token=to_token.room_key.as_historical_tuple() if to_token else None, engine=self.database_engine, ) @@ -166,12 +177,27 @@ class RelationsWorkerStore(SQLBaseStore): last_topo_id = row[1] last_stream_id = row[2] - next_batch = None + # If there are more events, generate the next pagination key. + next_token = None if len(events) > limit and last_topo_id and last_stream_id: - next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + next_key = RoomStreamToken(last_topo_id, last_stream_id) + if from_token: + next_token = from_token.copy_and_replace("room_key", next_key) + else: + next_token = StreamToken( + room_key=next_key, + presence_key=0, + typing_key=0, + receipt_key=0, + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=0, + groups_key=0, + ) return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token ) return await self.db_pool.runInteraction( @@ -340,20 +366,24 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit( - self, event_id: str, room_id: str - ) -> Optional[EventBase]: + def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") + async def _get_applicable_edits( + self, event_ids: Collection[str] + ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given - event. + events. Correctly handles checking whether edits were allowed to happen. Args: - event_id: The original event ID - room_id: The original event's room ID + event_ids: The original event IDs Returns: - The most recent edit, if any. + A map of the most recent edit for each event. If there are no edits, + the event will map to None. """ # We only allow edits for `m.room.message` events that have the same sender @@ -362,139 +392,238 @@ class RelationsWorkerStore(SQLBaseStore): # Fetches latest edit that has the same type and sender as the # original, and is an `m.room.message`. - sql = """ - SELECT edit.event_id FROM events AS edit - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS original ON - original.event_id = relates_to_id - AND edit.type = original.type - AND edit.sender = original.sender - WHERE - relates_to_id = ? - AND relation_type = ? - AND edit.room_id = ? - AND edit.type = 'm.room.message' - ORDER by edit.origin_server_ts DESC, edit.event_id DESC - LIMIT 1 - """ + if isinstance(self.database_engine, PostgresEngine): + # The `DISTINCT ON` clause will pick the *first* row it encounters, + # so ordering by origin server ts + event ID desc will ensure we get + # the latest edit. 
+ sql = """ + SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + AND edit.room_id = original.room_id + WHERE + %s + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC + """ + else: + # SQLite uses a simplified query which returns all edits for an + # original event. The results are then de-duplicated when turned into + # a dict. Due to the chosen ordering, the latest edit stomps on + # earlier edits. + sql = """ + SELECT original.event_id, edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + AND edit.room_id = original.room_id + WHERE + %s + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts, edit.event_id + """ - def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) - row = txn.fetchone() - if row: - return row[0] - return None + def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]: + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REPLACE) - edit_id = await self.db_pool.runInteraction( - "get_applicable_edit", _get_applicable_edit_txn + txn.execute(sql % (clause,), args) + return dict(cast(Iterable[Tuple[str, str]], txn.fetchall())) + + edit_ids = await self.db_pool.runInteraction( + "get_applicable_edits", _get_applicable_edits_txn ) - if not edit_id: - return None + edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined] - return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined] + # Map to the original event IDs to the edit events. + # + # There might not be an edit event due to there being no edits or + # due to the event not being known, either case is treated the same. + return { + original_event_id: edits.get(edit_ids.get(original_event_id)) + for original_event_id in event_ids + } @cached() - async def get_thread_summary( - self, event_id: str, room_id: str - ) -> Tuple[int, Optional[EventBase]]: + def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") + async def _get_thread_summaries( + self, event_ids: Collection[str] + ) -> Dict[str, Optional[Tuple[int, EventBase]]]: """Get the number of threaded replies and the latest reply (if any) for the given event. Args: - event_id: Summarize the thread related to this event ID. - room_id: The room the event belongs to. + event_ids: Summarize the thread related to this event ID. Returns: - The number of items in the thread and the most recent response, if any. + A map of the thread summary each event. A missing event implies there + are no threaded replies. + + Each summary includes the number of items in the thread and the most + recent response. """ - def _get_thread_summary_txn( + def _get_thread_summaries_txn( txn: LoggingTransaction, - ) -> Tuple[int, Optional[str]]: - # Fetch the latest event ID in the thread. 
+ ) -> Tuple[Dict[str, int], Dict[str, str]]: + # Fetch the count of threaded events and the latest event ID. # TODO Should this only allow m.room.message events. - sql = """ - SELECT event_id - FROM event_relations - INNER JOIN events USING (event_id) - WHERE - relates_to_id = ? - AND room_id = ? - AND relation_type = ? - ORDER BY topological_ordering DESC, stream_ordering DESC - LIMIT 1 - """ + if isinstance(self.database_engine, PostgresEngine): + # The `DISTINCT ON` clause will pick the *first* row it encounters, + # so ordering by topologica ordering + stream ordering desc will + # ensure we get the latest event in the thread. + sql = """ + SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id + WHERE + %s + AND relation_type = ? + ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC + """ + else: + # SQLite uses a simplified query which returns all entries for a + # thread. The first result for each thread is chosen to and subsequent + # results for a thread are ignored. + sql = """ + SELECT parent.event_id, child.event_id FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id + WHERE + %s + AND relation_type = ? + ORDER BY child.topological_ordering DESC, child.stream_ordering DESC + """ + + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.THREAD) - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) - row = txn.fetchone() - if row is None: - return 0, None + txn.execute(sql % (clause,), args) + latest_event_ids = {} + for parent_event_id, child_event_id in txn: + # Only consider the latest threaded reply (by topological ordering). + if parent_event_id not in latest_event_ids: + latest_event_ids[parent_event_id] = child_event_id - latest_event_id = row[0] + # If no threads were found, bail. + if not latest_event_ids: + return {}, latest_event_ids # Fetch the number of threaded replies. sql = """ - SELECT COUNT(event_id) - FROM event_relations - INNER JOIN events USING (event_id) + SELECT parent.event_id, COUNT(child.event_id) FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id WHERE - relates_to_id = ? - AND room_id = ? + %s AND relation_type = ? + GROUP BY parent.event_id """ - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) - count = cast(Tuple[int], txn.fetchone())[0] - return count, latest_event_id + # Regenerate the arguments since only threads found above could + # possibly have any replies. 
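[Aside, not part of the diff] The SQLite fallbacks in the edits and thread-summary hunks above both reduce sorted (parent, child) rows to one child per parent, just with opposite conventions: the edits query sorts oldest-first and lets later rows overwrite earlier ones, while the thread-summary query sorts newest-first and keeps only the first row seen per parent. In plain Python:

from typing import Dict, List, Tuple

# (parent event ID, child event ID) rows as the SQLite queries would return them.
rows_ascending: List[Tuple[str, str]] = [
    ("$parent", "$child_old"),
    ("$parent", "$child_new"),
]

# Edits path: rows come back oldest-first and dict() keeps the last value per
# key, so the newest child "stomps on" earlier ones.
assert dict(rows_ascending) == {"$parent": "$child_new"}

# Thread-summary path: rows come back newest-first and only the first child
# seen for each parent is kept.
latest: Dict[str, str] = {}
for parent_id, child_id in reversed(rows_ascending):
    if parent_id not in latest:
        latest[parent_id] = child_id
assert latest == {"$parent": "$child_new"}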
+ clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", latest_event_ids.keys() + ) + args.append(RelationTypes.THREAD) - count, latest_event_id = await self.db_pool.runInteraction( - "get_thread_summary", _get_thread_summary_txn + txn.execute(sql % (clause,), args) + counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) + + return counts, latest_event_ids + + counts, latest_event_ids = await self.db_pool.runInteraction( + "get_thread_summaries", _get_thread_summaries_txn ) - latest_event = None - if latest_event_id: - latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined] + latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] - return count, latest_event + # Map to the event IDs to the thread summary. + # + # There might not be a summary due to there not being a thread or + # due to the latest event not being known, either case is treated the same. + summaries = {} + for parent_event_id, latest_event_id in latest_event_ids.items(): + latest_event = latest_events.get(latest_event_id) + + summary = None + if latest_event: + summary = (counts[parent_event_id], latest_event) + summaries[parent_event_id] = summary + + return summaries @cached() - async def get_thread_participated( - self, event_id: str, room_id: str, user_id: str - ) -> bool: - """Get whether the requesting user participated in a thread. + def get_thread_participated(self, event_id: str, user_id: str) -> bool: + raise NotImplementedError() + + @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") + async def _get_threads_participated( + self, event_ids: Collection[str], user_id: str + ) -> Dict[str, bool]: + """Get whether the requesting user participated in the given threads. - This is separate from get_thread_summary since that can be cached across - all users while this value is specific to the requeser. + This is separate from get_thread_summaries since that can be cached across + all users while this value is specific to the requester. Args: - event_id: The thread related to this event ID. - room_id: The room the event belongs to. + event_ids: The thread related to these event IDs. user_id: The user requesting the summary. Returns: - True if the requesting user participated in the thread, otherwise false. + A map of event ID to a boolean which represents if the requesting + user participated in that event's thread, otherwise false. """ - def _get_thread_summary_txn(txn: LoggingTransaction) -> bool: + def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]: # Fetch whether the requester has participated or not. sql = """ - SELECT 1 - FROM event_relations - INNER JOIN events USING (event_id) + SELECT DISTINCT relates_to_id + FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id WHERE - relates_to_id = ? - AND room_id = ? + %s AND relation_type = ? - AND sender = ? + AND child.sender = ? 
""" - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id)) - return bool(txn.fetchone()) + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + args.extend((RelationTypes.THREAD, user_id)) - return await self.db_pool.runInteraction( + txn.execute(sql % (clause,), args) + return {row[0] for row in txn.fetchall()} + + participated_threads = await self.db_pool.runInteraction( "get_thread_summary", _get_thread_summary_txn ) + return {event_id: event_id in participated_threads for event_id in event_ids} + async def events_have_relations( self, parent_ids: List[str], @@ -612,9 +741,6 @@ class RelationsWorkerStore(SQLBaseStore): The bundled aggregations for an event, if bundled aggregations are enabled and the event can have bundled aggregations. """ - # State events and redacted events do not get bundled aggregations. - if event.is_state() or event.internal_metadata.is_redacted(): - return None # Do not bundle aggregations for an event which represents an edit or an # annotation. It does not make sense for them to have related events. @@ -634,43 +760,21 @@ class RelationsWorkerStore(SQLBaseStore): annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations.annotations = annotations.to_dict() + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations.references = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.get_applicable_edit(event_id, room_id) - - if edit: - aggregations.replace = edit - - # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - thread_count, latest_thread_event = await self.get_thread_summary( - event_id, room_id - ) - participated = await self.get_thread_participated( - event_id, room_id, user_id - ) - if latest_thread_event: - aggregations.thread = _ThreadAggregation( - latest_event=latest_thread_event, - count=thread_count, - current_user_participated=participated, - ) + aggregations.references = await references.to_dict(cast("DataStore", self)) # Store the bundled aggregations in the event metadata for later use. return aggregations async def get_bundled_aggregations( - self, - events: Iterable[EventBase], - user_id: str, + self, events: Iterable[EventBase], user_id: str ) -> Dict[str, BundledAggregations]: """Generate bundled aggregations for events. @@ -682,14 +786,59 @@ class RelationsWorkerStore(SQLBaseStore): A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ + # The already processed event IDs. Tracked separately from the result + # since the result omits events which do not have bundled aggregations. + seen_event_ids = set() - # TODO Parallelize. - results = {} + # State events and redacted events do not get bundled aggregations. + events = [ + event + for event in events + if not event.is_state() and not event.internal_metadata.is_redacted() + ] + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. for event in events: + # De-duplicate events by ID to handle the same event requested multiple + # times. The caches that _get_bundled_aggregation_for_event use should + # capture this, but best to reduce work. 
+ if event.event_id in seen_event_ids: + continue + seen_event_ids.add(event.event_id) + event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result + # Fetch any edits. + edits = await self._get_applicable_edits(seen_event_ids) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + if self._msc3440_enabled: + summaries = await self._get_thread_summaries(seen_event_ids) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated( + summaries.keys(), user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + return results |
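[Aside, not part of the diff] The reshaped get_bundled_aggregations above bulk-fetches each aggregation type (edits, thread summaries, participation) and merges them into per-event results with setdefault. A toy sketch of that merge shape (simplified data types, not the Synapse BundledAggregations class):

from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class ToyAggregations:
    replace: Optional[str] = None             # most recent edit event ID, if any
    thread: Optional[Tuple[int, str]] = None  # (reply count, latest reply ID)


def merge_bundles(
    edits: Dict[str, str],
    thread_summaries: Dict[str, Tuple[int, str]],
) -> Dict[str, ToyAggregations]:
    """Merge independently bulk-fetched aggregation types, one entry per event."""
    results: Dict[str, ToyAggregations] = {}
    for event_id, edit_id in edits.items():
        results.setdefault(event_id, ToyAggregations()).replace = edit_id
    for event_id, summary in thread_summaries.items():
        results.setdefault(event_id, ToyAggregations()).thread = summary
    return results


bundled = merge_bundles(
    edits={"$a": "$a_edit"},
    thread_summaries={"$a": (3, "$a_latest"), "$b": (1, "$b_latest")},
)
assert bundled["$a"].replace == "$a_edit" and bundled["$a"].thread == (3, "$a_latest")
assert bundled["$b"].replace is None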