summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/__init__.py48
-rw-r--r--synapse/storage/databases/main/account_data.py38
-rw-r--r--synapse/storage/databases/main/appservice.py29
-rw-r--r--synapse/storage/databases/main/cache.py11
-rw-r--r--synapse/storage/databases/main/deviceinbox.py92
-rw-r--r--synapse/storage/databases/main/devices.py354
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py8
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py32
-rw-r--r--synapse/storage/databases/main/event_federation.py55
-rw-r--r--synapse/storage/databases/main/event_push_actions.py324
-rw-r--r--synapse/storage/databases/main/events.py133
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py16
-rw-r--r--synapse/storage/databases/main/events_worker.py203
-rw-r--r--synapse/storage/databases/main/filtering.py4
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py8
-rw-r--r--synapse/storage/databases/main/push_rule.py41
-rw-r--r--synapse/storage/databases/main/pusher.py39
-rw-r--r--synapse/storage/databases/main/receipts.py164
-rw-r--r--synapse/storage/databases/main/registration.py183
-rw-r--r--synapse/storage/databases/main/relations.py554
-rw-r--r--synapse/storage/databases/main/room.py300
-rw-r--r--synapse/storage/databases/main/room_batch.py2
-rw-r--r--synapse/storage/databases/main/roommember.py53
-rw-r--r--synapse/storage/databases/main/search.py374
-rw-r--r--synapse/storage/databases/main/state.py2
-rw-r--r--synapse/storage/databases/main/stream.py60
-rw-r--r--synapse/storage/databases/main/user_directory.py80
27 files changed, 2430 insertions, 777 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py

index a62b4abd4e..0e47592be3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -26,9 +26,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict, get_domain_from_id -from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore @@ -138,41 +136,8 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._device_list_id_gen = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - super().__init__(database, db_conn, hs) - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - self._stream_order_on_start = self.get_room_max_stream_ordering() - self._min_stream_order_on_start = self.get_room_min_stream_ordering() - - def get_device_stream_token(self) -> int: - # TODO: shouldn't this be moved to `DeviceWorkerStore`? - return self._device_list_id_gen.get_current_token() - async def get_users(self) -> List[JsonDict]: """Function to retrieve a list of users in users table. @@ -201,7 +166,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: str = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.NAME.value, direction: str = "f", approved: bool = True, ) -> Tuple[List[JsonDict], int]: @@ -261,6 +226,7 @@ class DataStore( sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + LEFT JOIN erased_users AS eu ON u.name = eu.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -269,7 +235,8 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts, approved + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, + eu.user_id is not null as erased {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? @@ -277,6 +244,13 @@ class DataStore( args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) + + # some of those boolean values are returned as integers when we're on SQLite + columns_to_boolify = ["erased"] + for user in users: + for column in columns_to_boolify: + user[column] = bool(user[column]) + return users, count return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index c38b8a9e5a..07908c41d9 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( @@ -459,9 +449,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so simple_upsert will - # retry if there is a conflict. await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", @@ -471,7 +458,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) "account_data_type": account_data_type, }, values={"stream_id": next_id, "content": content_json}, - lock=False, ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) @@ -527,15 +513,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) -> None: content_json = json_encoder.encode(content) - # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so simple_upsert will retry if - # there is a conflict. self.db_pool.simple_upsert_txn( txn, table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, values={"stream_id": next_id, "content": content_json}, - lock=False, ) # Ignored users get denormalized into a separate table as an optimisation. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 64b70a7b28..c2c8018ee2 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -20,7 +20,7 @@ from synapse.appservice import ( ApplicationService, ApplicationServiceState, AppServiceTransaction, - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.config.appservice import load_appservices @@ -157,10 +157,23 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): app_service: "ApplicationService", cache_context: _CacheContext, ) -> List[str]: - users_in_room = await self.get_users_in_room( + """ + Get all users in a room that the appservice controls. + + Args: + room_id: The room to check in. + app_service: The application service to check interest/control against + + Returns: + List of user IDs that the appservice controls. + """ + # We can use `get_local_users_in_room(...)` here because an application service + # can only be interested in local users of the server it's on (ignore any remote + # users that might match the user namespace regex). + local_users_in_room = await self.get_local_users_in_room( room_id, on_invalidate=cache_context.invalidate ) - return list(filter(app_service.is_interested_in_user, users_in_room)) + return list(filter(app_service.is_interested_in_user, local_users_in_room)) class ApplicationServiceStore(ApplicationServiceWorkerStore): @@ -247,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore( events: List[EventBase], ephemeral: List[JsonDict], to_device_messages: List[JsonDict], - one_time_key_counts: TransactionOneTimeKeyCounts, + one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: @@ -260,7 +273,7 @@ class ApplicationServiceTransactionWorkerStore( events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction. - one_time_key_counts: Counts of remaining one-time keys for relevant + one_time_keys_count: Counts of remaining one-time keys for relevant appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. @@ -286,7 +299,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=ephemeral, to_device_messages=to_device_messages, - one_time_key_counts=one_time_key_counts, + one_time_keys_count=one_time_keys_count, unused_fallback_keys=unused_fallback_keys, device_list_summary=device_list_summary, ) @@ -366,7 +379,7 @@ class ApplicationServiceTransactionWorkerStore( events=events, ephemeral=[], to_device_messages=[], - one_time_key_counts={}, + one_time_keys_count={}, unused_fallback_keys={}, device_list_summary=DeviceListUpdates(), ) @@ -438,8 +451,6 @@ class ApplicationServiceTransactionWorkerStore( table="application_services_state", keyvalues={"as_id": service.id}, values={f"{stream_type}_stream_id": pos}, - # no need to lock when emulating upsert: as_id is a unique key - lock=False, desc="set_appservice_stream_type_pos", ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 3b8ed1f7ee..a58668a380 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -244,24 +244,29 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # redacted. self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self._attempt_to_invalidate_cache( "get_invited_rooms_for_local_user", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (state_key,) + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache( "get_aggregation_groups_for_event", (relates_to,) ) self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 73c95ffb6f..48a54d9cb8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -26,8 +26,15 @@ from typing import ( cast, ) +from synapse.api.constants import EventContentFields from synapse.logging import issue9533_logger -from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + trace, +) from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -397,6 +404,17 @@ class DeviceInboxWorkerStore(SQLBaseStore): (recipient_user_id, recipient_device_id), [] ).append(message_dict) + # start a new span for each message, so that we can tag each separately + with start_active_span("get_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, recipient_user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, recipient_device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + if limit is not None and rowcount == limit: # We ended up bumping up against the message limit. There may be more messages # to retrieve. Return what we have, as well as the last stream position that @@ -678,12 +696,35 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - if remote_messages_by_destination: - issue9533_logger.debug( - "Queued outgoing to-device messages with stream_id %i for %s", - stream_id, - list(remote_messages_by_destination.keys()), - ) + for destination, edu in remote_messages_by_destination.items(): + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Queued outgoing to-device messages with " + "stream_id %i, EDU message_id %s, type %s for %s: %s", + stream_id, + edu["message_id"], + edu["type"], + destination, + [ + f"{user_id}/{device_id} (msgid " + f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})" + for (user_id, messages_by_device) in edu["messages"].items() + for (device_id, msg) in messages_by_device.items() + ], + ) + + for (user_id, messages_by_device) in edu["messages"].items(): + for (device_id, msg) in messages_by_device.items(): + with start_active_span("store_outgoing_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) + set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) + set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg.get(EventContentFields.TO_DEVICE_MSGID), + ) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self._clock.time_msec() @@ -801,7 +842,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): # Only insert into the local inbox if the device exists on # this server device_id = row["device_id"] - message_json = json_encoder.encode(messages_by_device[device_id]) + + with start_active_span("serialise_to_device_message"): + msg = messages_by_device[device_id] + set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, msg["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + msg["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + message_json = json_encoder.encode(msg) + messages_json_for_user[device_id] = message_json if messages_json_for_user: @@ -821,15 +874,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - issue9533_logger.debug( - "Stored to-device messages with stream_id %i for %s", - stream_id, - [ - (user_id, device_id) - for (user_id, messages_by_device) in local_by_user_then_device.items() - for device_id in messages_by_device.keys() - ], - ) + if issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Stored to-device messages with stream_id %i: %s", + stream_id, + [ + f"{user_id}/{device_id} (msgid " + f"{msg['content'].get(EventContentFields.TO_DEVICE_MSGID)})" + for ( + user_id, + messages_by_device, + ) in messages_by_user_then_device.items() + for (device_id, msg) in messages_by_device.items() + ], + ) class DeviceInboxBackgroundUpdateStore(SQLBaseStore): diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 18358eca46..a5bb4d404e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import ( TYPE_CHECKING, @@ -39,6 +38,7 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -49,11 +49,19 @@ from synapse.storage.database import ( from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.caches.stream_change_cache import ( + AllEntitiesChangedResult, + StreamChangeCache, +) from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -80,9 +88,23 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) + # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). - device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined] + device_list_max = self._device_list_id_gen.get_current_token() device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( db_conn, "device_lists_stream", @@ -136,6 +158,39 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + self._invalidate_caches_for_devices(token, rows) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + for row in rows: + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) + + def _invalidate_caches_for_devices( + self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] + ) -> None: + for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) + + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) + + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -274,6 +329,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): destination, int(from_stream_id) ) if not has_changed: + # debugging for https://github.com/matrix-org/synapse/issues/14251 + issue_8631_logger.debug( + "%s: no change between %i and %i", + destination, + from_stream_id, + now_stream_id, + ) return now_stream_id, [] updates = await self.db_pool.runInteraction( @@ -466,7 +528,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): limit: Maximum number of device updates to return Returns: - List: List of device update tuples: + List of device update tuples: - user_id - device_id - stream_id @@ -539,9 +601,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "device_id": device_id, "prev_id": [prev_id] if prev_id else [], "stream_id": stream_id, - "org.matrix.opentracing_context": opentracing_context, } + if opentracing_context != "{}": + result["org.matrix.opentracing_context"] = opentracing_context + prev_id = stream_id if device is not None: @@ -549,7 +613,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if keys: result["keys"] = keys - device_display_name = device.display_name + device_display_name = None + if ( + self.hs.config.federation.allow_device_name_lookup_over_federation + ): + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name else: @@ -664,11 +732,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): }, ) - @abc.abstractmethod - def get_device_stream_token(self) -> int: - """Get the current stream id from the _device_list_id_gen""" - ... - @trace @cancellable async def get_user_devices_from_cache( @@ -739,7 +802,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_cached_device_list_changes( self, from_key: int, - ) -> Optional[List[str]]: + ) -> AllEntitiesChangedResult: """Get set of users whose devices have changed since `from_key`, or None if that information is not in our cache. """ @@ -747,10 +810,58 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return self._device_list_stream_cache.get_all_entities_changed(from_key) @cancellable + async def get_all_devices_changed( + self, + from_key: int, + to_key: int, + ) -> Set[str]: + """Get all users whose devices have changed in the given range. + + Args: + from_key: The minimum device lists stream token to query device list + changes for, exclusive. + to_key: The maximum device lists stream token to query device list + changes for, inclusive. + + Returns: + The set of user_ids whose devices have changed since `from_key` + (exclusive) until `to_key` (inclusive). + """ + + result = self._device_list_stream_cache.get_all_entities_changed(from_key) + + if result.hit: + # We know which users might have changed devices. + if not result.entities: + # If no users then we can return early. + return set() + + # Otherwise we need to filter down the list + return await self.get_users_whose_devices_changed( + from_key, result.entities, to_key + ) + + # If the cache didn't tell us anything, we just need to query the full + # range. + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE ? < stream_id AND stream_id <= ? + """ + + rows = await self.db_pool.execute( + "get_all_devices_changed", + None, + sql, + from_key, + to_key, + ) + return {u for u, in rows} + + @cancellable async def get_users_whose_devices_changed( self, from_key: int, - user_ids: Optional[Collection[str]] = None, + user_ids: Collection[str], to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that @@ -770,46 +881,31 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - user_ids_to_check: Optional[Collection[str]] - if user_ids is None: - # Get set of all users that have had device list changes since 'from_key' - user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( - from_key - ) - else: - # The same as above, but filter results to only those users in 'user_ids' - user_ids_to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) + # If an empty set was returned, there's nothing to do. if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - changes: Set[str] = set() - - stream_id_where_clause = "stream_id > ?" - sql_args = [from_key] - - if to_key: - stream_id_where_clause += " AND stream_id <= ?" - sql_args.append(to_key) + if to_key is None: + to_key = self._device_list_id_gen.get_current_token() - sql = f""" + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ SELECT DISTINCT user_id FROM device_lists_stream - WHERE {stream_id_where_clause} - AND + WHERE ? < stream_id AND stream_id <= ? AND %s """ + changes: Set[str] = set() + # Query device changes with a batch of users at a time - # Assertion for mypy's benefit; see also - # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions - assert user_ids_to_check is not None for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk ) - txn.execute(sql + clause, sql_args + args) + txn.execute(sql % (clause,), [from_key, to_key] + args) changes.update(user_id for user_id, in txn) return changes @@ -1381,6 +1477,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): self._remove_duplicate_outbound_pokes, ) + self.db_pool.updates.register_background_index_update( + "device_lists_changes_in_room_by_room_index", + index_name="device_lists_changes_in_room_by_room_idx", + table="device_lists_changes_in_room", + columns=["room_id", "stream_id"], + ) + async def _drop_device_list_streams_non_unique_indexes( self, progress: JsonDict, batch_size: int ) -> int: @@ -1468,6 +1571,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + # Because we have write access, this will be a StreamIdGenerator + # (see DeviceWorkerStore.__init__) + _device_list_id_gen: AbstractStreamIdGenerator + def __init__( self, database: DatabasePool, @@ -1673,9 +1780,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, values={"content": json_encoder.encode(content)}, - # we don't need to lock, because we assume we are the only thread - # updating this user's devices. - lock=False, ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) @@ -1689,9 +1793,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, values={"stream_id": stream_id}, - # again, we can assume we are the only thread updating this user's - # extremity. - lock=False, ) async def update_remote_device_list_cache( @@ -1744,9 +1845,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, values={"stream_id": stream_id}, - # we don't need to lock, because we can assume we are the only thread - # updating this user's extremity. - lock=False, ) async def add_device_change_to_streams( @@ -1792,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context, ) - async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1842,7 +1940,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Iterable[str], + device_id: str, hosts: Collection[str], stream_ids: List[int], context: Optional[Dict[str, str]], @@ -1858,6 +1956,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_id_iterator = iter(stream_ids) encoded_context = json_encoder.encode(context) + mark_sent = not self.hs.is_mine_id(user_id) + + values = [ + ( + destination, + next(stream_id_iterator), + user_id, + device_id, + mark_sent, + now, + encoded_context if whitelisted_homeserver(destination) else "{}", + ) + for destination in hosts + ] + self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", @@ -1870,23 +1983,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "ts", "opentracing_context", ), - values=[ - ( - destination, - next(stream_id_iterator), - user_id, - device_id, - not self.hs.is_mine_id( - user_id - ), # We only need to send out update for *our* users - now, - encoded_context if whitelisted_homeserver(destination) else "{}", - ) - for destination in hosts - for device_id in device_ids - ], + values=values, ) + # debugging for https://github.com/matrix-org/synapse/issues/14251 + if issue_8631_logger.isEnabledFor(logging.DEBUG): + issue_8631_logger.debug( + "Recorded outbound pokes for %s:%s with device stream ids %s", + user_id, + device_id, + { + stream_id: destination + for (destination, stream_id, _, _, _, _, _) in values + }, + ) + def _add_device_outbound_room_poke_txn( self, txn: LoggingTransaction, @@ -1931,27 +2042,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def get_uncoverted_outbound_room_pokes( - self, limit: int = 10 + self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: """Get device list changes by room that have not yet been handled and written to `device_lists_outbound_pokes`. + Args: + start_stream_id: Together with `start_room_id`, indicates the position after + which to return device list changes. + start_room_id: Together with `start_stream_id`, indicates the position after + which to return device list changes. + limit: The maximum number of device list changes to return. + Returns: - A list of user ID, device ID, room ID, stream ID and optional opentracing context. + A list of user ID, device ID, room ID, stream ID and optional opentracing + context, in order of ascending (stream ID, room ID). """ sql = """ SELECT user_id, device_id, room_id, stream_id, opentracing_context FROM device_lists_changes_in_room - WHERE NOT converted_to_destinations - ORDER BY stream_id + WHERE + (stream_id, room_id) > (?, ?) AND + stream_id <= ? AND + NOT converted_to_destinations + ORDER BY stream_id ASC, room_id ASC LIMIT ? """ def get_uncoverted_outbound_room_pokes_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: - txn.execute(sql, (limit,)) + txn.execute( + sql, + ( + start_stream_id, + start_room_id, + # Avoid returning rows if there may be uncommitted device list + # changes with smaller stream IDs. + self._device_list_id_gen.get_current_token(), + limit, + ), + ) return [ ( @@ -1973,52 +2105,28 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, room_id: str, - stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - - Marks the associated row in `device_lists_changes_in_room` as handled, - if `stream_id` is provided. """ + if not hosts: + return def add_device_list_outbound_pokes_txn( txn: LoggingTransaction, stream_ids: List[int] ) -> None: - if hosts: - self._add_device_outbound_poke_to_stream_txn( - txn, - user_id=user_id, - device_ids=[device_id], - hosts=hosts, - stream_ids=stream_ids, - context=context, - ) - - if stream_id: - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) - - if not hosts: - # If there are no hosts then we don't try and generate stream IDs. - return await self.db_pool.runInteraction( - "add_device_list_outbound_pokes", - add_device_list_outbound_pokes_txn, - [], + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_id=device_id, + hosts=hosts, + stream_ids=stream_ids, + context=context, ) - async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: return await self.db_pool.runInteraction( "add_device_list_outbound_pokes", add_device_list_outbound_pokes_txn, @@ -2032,7 +2140,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates during partial joins. """ - async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.simple_upsert( table="device_lists_remote_pending", keyvalues={ @@ -2079,3 +2187,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "get_pending_remote_device_list_updates_for_room", get_pending_remote_device_list_updates_for_room_txn, ) + + async def get_device_change_last_converted_pos(self) -> Tuple[int, str]: + """ + Get the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + + Rows with a strictly greater position where `converted_to_destinations` is + `FALSE` have not been converted. + """ + + row = await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ) + return row["stream_id"], row["room_id"] + + async def set_device_change_last_converted_pos( + self, + stream_id: int, + room_id: str, + ) -> None: + """ + Set the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + """ + + await self.db_pool.simple_update_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + updatevalues={"stream_id": stream_id, "room_id": room_id}, + desc="set_device_change_last_converted_pos", + ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index af59be6b48..6240f9a75e 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -391,10 +391,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): Returns: A dict giving the info metadata for this backup version, with fields including: - version(str) - algorithm(str) - auth_data(object): opaque dict supplied by the client - etag(int): tag of the keys in the backup + version (str) + algorithm (str) + auth_data (object): opaque dict supplied by the client + etag (int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 8a10ae800c..4c691642e2 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -33,7 +33,7 @@ from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( - TransactionOneTimeKeyCounts, + TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys, ) from synapse.logging.opentracing import log_kv, set_tag, trace @@ -139,11 +139,15 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @trace @cancellable async def get_e2e_device_keys_for_cs_api( - self, query_list: List[Tuple[str, Optional[str]]] + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_displaynames: bool = True, ) -> Dict[str, Dict[str, JsonDict]]: """Fetch a list of device keys, formatted suitably for the C/S API. Args: - query_list(list): List of pairs of user_ids and device_ids. + query_list: List of pairs of user_ids and device_ids. + include_displaynames: Whether to include the displayname of returned devices + (if one exists). Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -166,9 +170,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker continue r["unsigned"] = {} - display_name = device_info.display_name - if display_name is not None: - r["unsigned"]["device_display_name"] = display_name + if include_displaynames: + # Include the device's display name in the "unsigned" dictionary + display_name = device_info.display_name + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + rv[user_id][device_id] = r return rv @@ -405,10 +412,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """Retrieve a number of one-time keys for a user Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve + user_id: id of user to get keys for + device_id: id of device to get keys for + key_ids: list of key ids (excluding algorithm) to retrieve Returns: A map from (algorithm, key_id) to json string for key @@ -508,7 +514,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def count_bulk_e2e_one_time_keys_for_as( self, user_ids: Collection[str] - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: """ Counts, in bulk, the one-time keys for all the users specified. Intended to be used by application services for populating OTK counts in @@ -522,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _count_bulk_e2e_one_time_keys_txn( txn: LoggingTransaction, - ) -> TransactionOneTimeKeyCounts: + ) -> TransactionOneTimeKeysCount: user_in_where_clause, user_parameters = make_in_list_sql_clause( self.database_engine, "user_id", user_ids ) @@ -535,7 +541,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ txn.execute(sql, user_parameters) - result: TransactionOneTimeKeyCounts = {} + result: TransactionOneTimeKeysCount = {} for user_id, device_id, algorithm, count in txn: # We deliberately construct empty dictionaries for diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6b9a629edd..bbee02ab18 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter down the events to ones that we've failed to pull before recently. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that should not be attempted to be pulled + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, @@ -1632,7 +1686,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas }, insertion_values={}, desc="insert_insertion_extremity", - lock=False, ) async def insert_received_event_to_staging( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 332e13d1c9..7ebe34f773 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,6 +74,7 @@ receipt. """ import logging +from collections import defaultdict from typing import ( TYPE_CHECKING, Collection, @@ -95,6 +96,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + PostgresEngine, ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore @@ -294,6 +296,44 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._background_backfill_thread_id, ) + # Indexes which will be used to quickly make the thread_id column non-null. + self.db_pool.updates.register_background_index_update( + "event_push_actions_thread_id_null", + index_name="event_push_actions_thread_id_null", + table="event_push_actions", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + self.db_pool.updates.register_background_index_update( + "event_push_summary_thread_id_null", + index_name="event_push_summary_thread_id_null", + table="event_push_summary", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates the event_push_actions and event_push_summary tables. + self._clock.call_later(0.0, self._check_event_push_backfill_thread_id) + self._event_push_backfill_thread_id_done = False + + @wrap_as_background_process("check_event_push_backfill_thread_id") + async def _check_event_push_backfill_thread_id(self) -> None: + """ + Has thread_id finished backfilling? + + If not, we need to just-in-time update it so the queries work. + """ + done = await self.db_pool.updates.has_completed_background_update( + "event_push_backfill_thread_id" + ) + + if done: + self._event_push_backfill_thread_id_done = True + else: + # Reschedule to run. + self._clock.call_later(15.0, self._check_event_push_backfill_thread_id) + async def _background_backfill_thread_id( self, progress: JsonDict, batch_size: int ) -> int: @@ -310,11 +350,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas event_push_actions_done = progress.get("event_push_actions_done", False) def add_thread_id_txn( - txn: LoggingTransaction, table_name: str, start_stream_ordering: int + txn: LoggingTransaction, start_stream_ordering: int ) -> int: - sql = f""" + sql = """ SELECT stream_ordering - FROM {table_name} + FROM event_push_actions WHERE thread_id IS NULL AND stream_ordering > ? @@ -326,7 +366,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # No more rows to process. rows = txn.fetchall() if not rows: - progress[f"{table_name}_done"] = True + progress["event_push_actions_done"] = True self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -335,16 +375,65 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Update the thread ID for any of those rows. max_stream_ordering = rows[-1][0] - sql = f""" - UPDATE {table_name} + sql = """ + UPDATE event_push_actions SET thread_id = 'main' - WHERE stream_ordering <= ? AND thread_id IS NULL + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL """ - txn.execute(sql, (max_stream_ordering,)) + txn.execute( + sql, + ( + start_stream_ordering, + max_stream_ordering, + ), + ) # Update progress. processed_rows = txn.rowcount - progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering + progress["max_event_push_actions_stream_ordering"] = max_stream_ordering + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + + return processed_rows + + def add_thread_id_summary_txn(txn: LoggingTransaction) -> int: + min_user_id = progress.get("max_summary_user_id", "") + min_room_id = progress.get("max_summary_room_id", "") + + # Slightly overcomplicated query for getting the Nth user ID / room + # ID tuple, or the last if there are less than N remaining. + sql = """ + SELECT user_id, room_id FROM ( + SELECT user_id, room_id FROM event_push_summary + WHERE (user_id, room_id) > (?, ?) + AND thread_id IS NULL + ORDER BY user_id, room_id + LIMIT ? + ) AS e + ORDER BY user_id DESC, room_id DESC + LIMIT 1 + """ + + txn.execute(sql, (min_user_id, min_room_id, batch_size)) + row = txn.fetchone() + if not row: + return 0 + + max_user_id, max_room_id = row + + sql = """ + UPDATE event_push_summary + SET thread_id = 'main' + WHERE + (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?) + AND thread_id IS NULL + """ + txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id)) + processed_rows = txn.rowcount + + progress["max_summary_user_id"] = max_user_id + progress["max_summary_room_id"] = max_room_id self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -360,15 +449,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", add_thread_id_txn, - "event_push_actions", progress.get("max_event_push_actions_stream_ordering", 0), ) else: result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", - add_thread_id_txn, - "event_push_summary", - progress.get("max_event_push_summary_stream_ordering", 0), + add_thread_id_summary_txn, ) # Only done after the event_push_summary table is done. @@ -379,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result + async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]: + """Get the notification count by room for a user. Only considers notifications, + not highlight or unread counts, and threads are currently aggregated under their room. + + This function is intentionally not cached because it is called to calculate the + unread badge for push notifications and thus the result is expected to change. + + Note that this function assumes the user is a member of the room. Because + summary rows are not removed when a user leaves a room, the caller must + filter out those results from the result. + + Returns: + A map of room ID to notification counts for the given user. + """ + return await self.db_pool.runInteraction( + "get_unread_counts_by_room_for_user", + self._get_unread_counts_by_room_for_user_txn, + user_id, + ) + + def _get_unread_counts_by_room_for_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Dict[str, int]: + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) + args.extend([user_id, user_id]) + + receipts_cte = f""" + WITH all_receipts AS ( + SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + {receipt_types_clause} + AND user_id = ? + GROUP BY room_id, thread_id + ) + """ + + receipts_joins = """ + LEFT JOIN ( + SELECT room_id, thread_id, + max_receipt_stream_ordering AS threaded_receipt_stream_ordering + FROM all_receipts + WHERE thread_id IS NOT NULL + ) AS threaded_receipts USING (room_id, thread_id) + LEFT JOIN ( + SELECT room_id, thread_id, + max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering + FROM all_receipts + WHERE thread_id IS NULL + ) AS unthreaded_receipts USING (room_id) + """ + + # First get summary counts by room / thread for the user. We use the max receipt + # stream ordering of both threaded & unthreaded receipts to compare against the + # summary table. + # + # PostgreSQL and SQLite differ in comparing scalar numerics. + if isinstance(self.database_engine, PostgresEngine): + # GREATEST ignores NULLs. + max_clause = """GREATEST( + threaded_receipt_stream_ordering, + unthreaded_receipt_stream_ordering + )""" + else: + # MAX returns NULL if any are NULL, so COALESCE to 0 first. + max_clause = """MAX( + COALESCE(threaded_receipt_stream_ordering, 0), + COALESCE(unthreaded_receipt_stream_ordering, 0) + )""" + + sql = f""" + {receipts_cte} + SELECT eps.room_id, eps.thread_id, notif_count + FROM event_push_summary AS eps + {receipts_joins} + WHERE user_id = ? + AND notif_count != 0 + AND ( + (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause}) + OR last_receipt_stream_ordering = {max_clause} + ) + """ + txn.execute(sql, args) + + seen_thread_ids = set() + room_to_count: Dict[str, int] = defaultdict(int) + + for room_id, thread_id, notif_count in txn: + room_to_count[room_id] += notif_count + seen_thread_ids.add(thread_id) + + # Now get any event push actions that haven't been rotated using the same OR + # join and filter by receipt and event push summary rotated up to stream ordering. + sql = f""" + {receipts_cte} + SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count + FROM event_push_actions AS epa + {receipts_joins} + WHERE user_id = ? + AND epa.notif = 1 + AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering) + AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) + AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) + GROUP BY epa.room_id, epa.thread_id + """ + txn.execute(sql, args) + + for room_id, thread_id, notif_count in txn: + # Note: only count push actions we have valid summaries for with up to date receipt. + if thread_id not in seen_thread_ids: + continue + room_to_count[room_id] += notif_count + + thread_id_clause, thread_ids_args = make_in_list_sql_clause( + self.database_engine, "epa.thread_id", seen_thread_ids + ) + + # Finally re-check event_push_actions for any rooms not in the summary, ignoring + # the rotated up-to position. This handles the case where a read receipt has arrived + # but not been rotated meaning the summary table is out of date, so we go back to + # the push actions table. + sql = f""" + {receipts_cte} + SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count + FROM event_push_actions AS epa + {receipts_joins} + WHERE user_id = ? + AND NOT {thread_id_clause} + AND epa.notif = 1 + AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) + AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) + GROUP BY epa.room_id + """ + + args.extend(thread_ids_args) + txn.execute(sql, args) + + for room_id, notif_count in txn: + room_to_count[room_id] += notif_count + + return room_to_count + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, @@ -480,6 +713,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + # First we pull the counts from the summary table. # # We check that `last_receipt_stream_ordering` matches the stream ordering of the @@ -1295,6 +1547,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (room_id, user_id, stream_ordering, *thread_args), ) + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. unread_counts = self._get_notif_unread_count_for_user_room( @@ -1429,6 +1700,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas rotate_to_stream_ordering: The new maximum event stream ordering to summarise. """ + # Ensure that any new actions have an updated thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL + """, + (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + # XXX Do we need to update summaries here too? + # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, thread_id, @@ -1491,6 +1775,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) + # Ensure that any updated threads have the proper thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute_batch( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + [ + (MAIN_TIMELINE, room_id, user_id) + for user_id, room_id, _ in summaries + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3e15827986..0f097a2927 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -355,9 +355,9 @@ class PersistEventsStore: txn: LoggingTransaction, *, events_and_contexts: List[Tuple[EventBase, EventContext]], - inhibit_local_membership_updates: bool = False, - state_delta_for_room: Optional[Dict[str, DeltaState]] = None, - new_forward_extremities: Optional[Dict[str, Set[str]]] = None, + inhibit_local_membership_updates: bool, + state_delta_for_room: Dict[str, DeltaState], + new_forward_extremities: Dict[str, Set[str]], ) -> None: """Insert some number of room events into the necessary database tables. @@ -384,9 +384,6 @@ class PersistEventsStore: PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ - state_delta_for_room = state_delta_for_room or {} - new_forward_extremities = new_forward_extremities or {} - all_events_and_contexts = events_and_contexts min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering @@ -1282,9 +1279,10 @@ class PersistEventsStore: Pick the earliest non-outlier if there is one, else the earliest one. Args: - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts: + Returns: - list[(EventBase, EventContext)]: filtered list + filtered list """ new_events_and_contexts: OrderedDict[ str, Tuple[EventBase, EventContext] @@ -1310,9 +1308,8 @@ class PersistEventsStore: """Update min_depth for each room Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: @@ -1583,13 +1580,11 @@ class PersistEventsStore: """Update all the miscellaneous tables for new events Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + txn: db connection + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. inhibit_local_membership_updates: Stop the local_current_membership from being updated by these events. This should be set to True for backfilled events because backfilled events in the past do @@ -1616,7 +1611,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1861,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,35 +2012,52 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ - # Fetch the current relation of the event being redacted. - redacted_relates_to = self.db_pool.simple_select_one_onecol_txn( + # Fetch the relation of the event being redacted. + row = self.db_pool.simple_select_one_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id}, - retcol="relates_to_id", + retcols=("relates_to_id", "relation_type"), allow_none=True, ) + # Nothing to do if no relation is found. + if row is None: + return + + redacted_relates_to = row["relates_to_id"] + rel_type = row["relation_type"] + self.db_pool.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + # Any relation information for the related event must be cleared. - if redacted_relates_to is not None: + self.store._invalidate_cache_and_stream( + txn, self.store.get_relations_for_event, (redacted_relates_to,) + ) + if rel_type == RelationTypes.ANNOTATION: self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) + txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REFERENCE: self.store._invalidate_cache_and_stream( - txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) + txn, self.store.get_references_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) ) + if rel_type == RelationTypes.THREAD: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_summary, (redacted_relates_to,) ) @@ -2025,14 +2065,41 @@ class PersistEventsStore: txn, self.store.get_thread_participated, (redacted_relates_to,) ) self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), + txn, self.store.get_threads, (room_id,) ) - self.db_pool.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + # Find the new latest event in the thread. + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (redacted_relates_to, RelationTypes.THREAD)) + + # If a latest event is found, update the threads table, this might + # be the same current latest event (if an earlier event in the thread + # was redacted). + latest_event_row = txn.fetchone() + if latest_event_row: + self.db_pool.simple_upsert_txn( + txn, + table="threads", + keyvalues={"room_id": room_id, "thread_id": redacted_relates_to}, + values={ + "latest_event_id": latest_event_row[0], + "topological_ordering": latest_event_row[1], + "stream_ordering": latest_event_row[2], + }, + ) + + # Otherwise, delete the thread: it no longer exists. + else: + self.db_pool.simple_delete_one_txn( + txn, table="threads", keyvalues={"thread_id": redacted_relates_to} + ) def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 6e8aeed7b4..9e31798ab1 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1435,16 +1435,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ), ) - endpoint = None row = txn.fetchone() if row: endpoint = row[0] + else: + # if the query didn't return a row, we must be almost done. We just + # need to go up to the recorded max_stream_ordering. + endpoint = max_stream_ordering_inclusive - where_clause = "stream_ordering > ?" - args = [min_stream_ordering_exclusive] - if endpoint: - where_clause += " AND stream_ordering <= ?" - args.append(endpoint) + where_clause = "stream_ordering > ? AND stream_ordering <= ?" + args = [min_stream_ordering_exclusive, endpoint] # now do the updates. txn.execute( @@ -1458,13 +1458,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) logger.info( - "populated new `events` columns up to %s/%i: updated %i rows", + "populated new `events` columns up to %i/%i: updated %i rows", endpoint, max_stream_ordering_inclusive, txn.rowcount, ) - if endpoint is None: + if endpoint >= max_stream_ordering_inclusive: # we're done return True diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 7cdc9fe98f..318fd7dc71 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging import threading import weakref from enum import Enum, auto +from itertools import chain from typing import ( TYPE_CHECKING, Any, Collection, - Container, Dict, Iterable, List, @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -77,10 +76,12 @@ from synapse.storage.util.id_generators import ( ) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import AsyncLruCache +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -212,26 +213,35 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) if hs.config.worker.run_background_tasks: # We periodically clean out old transaction ID mappings @@ -374,7 +384,7 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - The event, or None if the event was not found. + The event, or None if the event was not found and allow_none is `True`. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) @@ -474,7 +484,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = await self._get_events_from_cache_or_db( + event_entry_map = await self.get_unredacted_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -509,7 +519,9 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = await self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self.get_unredacted_events_from_cache_or_db( + [redacted_event_id] + ) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -588,11 +600,16 @@ class EventsWorkerStore(SQLBaseStore): return events @cancellable - async def _get_events_from_cache_or_db( - self, event_ids: Iterable[str], allow_rejected: bool = False + async def get_unredacted_events_from_cache_or_db( + self, + event_ids: Iterable[str], + allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. + Note that the events pulled by this function will not have any redactions + applied, and no guarantee is made about the ordering of the events returned. + If events are pulled from the database, they will be cached for future lookups. Unknown events are omitted from the response. @@ -863,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: Container[str], + state_keys_to_include: StateFilter, membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -876,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore): Args: context: The event context to retrieve state of the room from. - state_types_to_include: The type of state events to include. + state_keys_to_include: The state events to include, for each event type. membership_user_id: An optional user ID to include the stripped membership state events of. This is useful when generating the stripped state of a room for invites. We want to send membership events of the inviter, so that the @@ -885,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore): Returns: A list of dictionaries, each representing a stripped state event from the room. """ - current_state_ids = await context.get_current_state_ids() + if membership_user_id: + types = chain( + state_keys_to_include.to_types(), + [(EventTypes.Member, membership_user_id)], + ) + filter = StateFilter.from_types(types) + else: + filter = state_keys_to_include + selected_state_ids = await context.get_current_state_ids(filter) # We know this event is not an outlier, so this must be # non-None. - assert current_state_ids is not None - - # The state to include - state_to_include_ids = [ - e_id - for k, e_id in current_state_ids.items() - if k[0] in state_types_to_include - or (membership_user_id and k == (EventTypes.Member, membership_user_id)) - ] + assert selected_state_ids is not None + + # Confusingly, get_current_state_events may return events that are discarded by + # the filter, if they're in context._state_delta_due_to_event. Strip these away. + selected_state_ids = filter.filter_state(selected_state_ids) - state_to_include = await self.get_events(state_to_include_ids) + state_to_include = await self.get_events(selected_state_ids.values()) return [ { @@ -1495,21 +1516,15 @@ class EventsWorkerStore(SQLBaseStore): Returns: a dict {event_id -> bool} """ - # if the event cache contains the event, obviously we've seen it. - - cache_results = { - event_id - for event_id in event_ids - if await self._get_event_cache.contains((event_id,)) - } - results = dict.fromkeys(cache_results, True) - remaining = [ - event_id for event_id in event_ids if event_id not in cache_results - ] - if not remaining: - return results + # TODO: We used to query the _get_event_cache here as a fast-path before + # hitting the database. For if an event were in the cache, we've presumably + # seen it before. + # + # But this is currently an invalid assumption due to the _get_event_cache + # not being invalidated when purging events from a room. The optimisation can + # be re-added after https://github.com/matrix-org/synapse/issues/13476 - def have_seen_events_txn(txn: LoggingTransaction) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1517,16 +1532,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", remaining + txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update({eid: (eid in found_events) for eid in remaining}) + return {eid: (eid in found_events) for eid in event_ids} - await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) - return results + return await self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn + ) @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: @@ -1571,7 +1587,7 @@ class EventsWorkerStore(SQLBaseStore): room_id: The room ID to query. Returns: - dict[str:float] of complexity version to complexity. + Map of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1969,12 +1985,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_backward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_backward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question has any of its prev_events listed as a # backward extremity, it's next to a gap. @@ -2024,12 +2045,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_forward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question is a forward extremity, we will just # consider any potential forward gap as not a gap since it's one of @@ -2110,13 +2136,33 @@ class EventsWorkerStore(SQLBaseStore): The closest event_id otherwise None if we can't find any event in the given direction. """ + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" - sql_template = """ + sql_template = f""" SELECT event_id FROM events LEFT JOIN rejections USING (event_id) WHERE - origin_server_ts %s ? - AND room_id = ? + room_id = ? + AND origin_server_ts {comparison_operator} ? + /** + * Make sure the event isn't an `outlier` because we have no way + * to later check whether it's next to a gap. `outliers` do not + * have entries in the `event_edges`, `event_forward_extremeties`, + * and `event_backward_extremities` tables to check against + * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`). + */ + AND NOT outlier /* Make sure event is not rejected */ AND rejections.event_id IS NULL /** @@ -2126,27 +2172,14 @@ class EventsWorkerStore(SQLBaseStore): * Finally, we can tie-break based on when it was received on the server * (`stream_ordering`). */ - ORDER BY origin_server_ts %s, depth %s, stream_ordering %s + ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order} LIMIT 1; """ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - txn.execute( - sql_template % (comparison_operator, order, order, order), - (timestamp, room_id), + sql_template, + (room_id, timestamp), ) row = txn.fetchone() if row: @@ -2200,7 +2233,15 @@ class EventsWorkerStore(SQLBaseStore): return result is not None async def get_partial_state_events_batch(self, room_id: str) -> List[str]: - """Get a list of events in the given room that have partial state""" + """ + Get a list of events in the given room that: + - have partial state; and + - are ready to be resynced (because they have no prev_events that are + partial-stated) + + See the docstring on `_get_partial_state_events_batch_txn` for more + information. + """ return await self.db_pool.runInteraction( "get_partial_state_events_batch", self._get_partial_state_events_batch_txn, diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cb9ee08fa8..12f3b601f1 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py
@@ -24,7 +24,7 @@ from synapse.types import JsonDict from synapse.util.caches.descriptors import cached -class FilteringStore(SQLBaseStore): +class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( self, user_localpart: str, filter_id: Union[int, str] @@ -46,6 +46,8 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) + +class FilteringStore(FilteringWorkerStore): async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index efd136a864..db9a24db5e 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -217,7 +217,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: - reserved_users (tuple): reserved users to preserve + reserved_users: reserved users to preserve """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) @@ -370,8 +370,8 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): should not appear in the MAU stats). Args: - txn (cursor): - user_id (str): user to add/update + txn: + user_id: user to add/update """ assert ( self._update_on_this_worker @@ -401,7 +401,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): add the user to the monthly active tables Args: - user_id(str): the user_id to query + user_id: the user_id to query """ assert ( self._update_on_this_worker diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 8295322b0e..d4c64c46ad 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, + Iterable, List, Mapping, Optional, @@ -30,7 +30,7 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -84,14 +84,15 @@ def _load_rules( push_rules = PushRules(ruleslist) filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled + push_rules, + enabled_map, + msc3664_enabled=experimental_config.msc3664_enabled, + msc1767_enabled=experimental_config.msc1767_enabled, ) return filtered_rules -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, PusherWorkerStore, @@ -99,7 +100,6 @@ class PushRulesWorkerStore( ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore, - metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. @@ -113,14 +113,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, @@ -136,14 +136,23 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - @abc.abstractmethod def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. Returns: int """ - raise NotImplementedError() + return self._push_rules_stream_id_gen.get_current_token() + + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + for row in rows: + self.get_push_rules_for_user.invalidate((row.user_id,)) + self.push_rules_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 01206950a9..40fd781a6a 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -27,13 +27,18 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams +from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -52,8 +57,15 @@ class PusherWorkerStore(SQLBaseStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, ) self.db_pool.updates.register_background_update_handler( @@ -96,6 +108,16 @@ class PusherWorkerStore(SQLBaseStore): yield PusherConfig(**r) + def get_pushers_stream_token(self) -> int: + return self._pushers_id_gen.get_current_token() + + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: + if stream_name == PushersStream.NAME: + self._pushers_id_gen.advance(instance_name, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) + async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str ) -> Iterator[PusherConfig]: @@ -303,14 +325,11 @@ class PusherWorkerStore(SQLBaseStore): async def set_throttle_params( self, pusher_id: str, room_id: str, params: ThrottleParams ) -> None: - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so simple_upsert will retry await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms}, desc="set_throttle_params", - lock=False, ) async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int: @@ -545,8 +564,9 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): - def get_pushers_stream_token(self) -> int: - return self._pushers_id_gen.get_current_token() + # Because we have write access, this will be a StreamIdGenerator + # (see PusherWorkerStore.__init__) + _pushers_id_gen: AbstractStreamIdGenerator async def add_pusher( self, @@ -566,8 +586,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): device_id: Optional[str] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: - # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -586,7 +604,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): "device_id": device_id, }, desc="add_pusher", - lock=False, ) user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 246f78ac1f..e06725f69c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) @@ -113,24 +113,6 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - self.db_pool.updates.register_background_index_update( - "receipts_linearized_unique_index", - index_name="receipts_linearized_unique_index", - table="receipts_linearized", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - - self.db_pool.updates.register_background_index_update( - "receipts_graph_unique_index", - index_name="receipts_graph_unique_index", - table="receipts_graph", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -418,6 +400,8 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type[row["user_id"]] = db_to_json(row["data"]) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] results = { room_id: [results[room_id]] if room_id in results else [] @@ -700,9 +684,6 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_linearized has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) return rx_ts @@ -860,14 +841,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_graph has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) class ReceiptsBackgroundUpdateStore(SQLBaseStore): POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" + RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME = "receipts_linearized_unique_index" + RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME = "receipts_graph_unique_index" def __init__( self, @@ -881,6 +861,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, self._populate_receipt_event_stream_ordering, ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_linearized_unique_index, + ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_graph_unique_index, + ) async def _populate_receipt_event_stream_ordering( self, progress: JsonDict, batch_size: int @@ -936,6 +924,116 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _background_receipts_linearized_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_linearized`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT MAX(stream_id), room_id, receipt_type, user_id + FROM receipts_linearized + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn)) + + # Then remove duplicate receipts, keeping the one with the highest + # `stream_id`. There should only be a single receipt with any given + # `stream_id`. + for max_stream_id, room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_linearized + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL AND + stream_id < ? + """ + txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self.db_pool.updates.create_index_in_background( + index_name="receipts_linearized_unique_index", + table="receipts_linearized", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + + async def _background_receipts_graph_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_graph`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT room_id, receipt_type, user_id FROM receipts_graph + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[str, str, str]], list(txn)) + + # Then remove all duplicate receipts. + # We could be clever and try to keep the latest receipt out of every set of + # duplicates, but it's far simpler to remove them all. + for room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_graph + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL + """ + txn.execute(sql, (room_id, receipt_type, user_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self.db_pool.updates.create_index_in_background( + index_name="receipts_graph_unique_index", + table="receipts_graph", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore): pass diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 2996d6bb4d..31f0f2bd3d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + NotFoundError, + StoreError, + SynapseError, + ThreepidValidationError, +) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception): because this external id is given to an other user.""" +class LoginTokenExpired(Exception): + """Exception if the login token sent expired""" + + +class LoginTokenReused(Exception): + """Exception if the login token sent was already used""" + + @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: """Result of looking up an access token. @@ -115,6 +129,20 @@ class RefreshTokenLookupResult: If None, the session can be refreshed indefinitely.""" +@attr.s(auto_attribs=True, frozen=True, slots=True) +class LoginTokenLookupResult: + """Result of looking up a login token.""" + + user_id: str + """The user this token belongs to.""" + + auth_provider_id: Optional[str] + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id: Optional[str] + """The session ID advertised by the SSO Identity Provider.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -925,7 +953,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Returns user id from threepid Args: - txn (cursor): + txn: medium: threepid medium e.g. email address: threepid address e.g. me@example.com @@ -1255,8 +1283,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Sets an expiration date to the account with the given user ID. Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be + user_id: User ID to set an expiration date for. + use_delta: If set to False, the expiration date for the user will be now + validity period. If set to True, this expiration date will be a random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. @@ -1789,6 +1817,130 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + async def add_login_token_to_user( + self, + user_id: str, + token: str, + expiry_ts: int, + auth_provider_id: Optional[str], + auth_provider_session_id: Optional[str], + ) -> None: + """Adds a short-term login token for the given user. + + Args: + user_id: The user ID. + token: The new login token to add. + expiry_ts (milliseconds since the epoch): Time after which the login token + cannot be used. + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token, if any + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider, if any. + """ + await self.db_pool.simple_insert( + "login_tokens", + { + "token": token, + "user_id": user_id, + "expiry_ts": expiry_ts, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="add_login_token_to_user", + ) + + def _consume_login_token( + self, + txn: LoggingTransaction, + token: str, + ts: int, + ) -> LoginTokenLookupResult: + values = self.db_pool.simple_select_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + retcols=( + "user_id", + "expiry_ts", + "used_ts", + "auth_provider_id", + "auth_provider_session_id", + ), + allow_none=True, + ) + + if values is None: + raise NotFoundError() + + self.db_pool.simple_update_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + updatevalues={"used_ts": ts}, + ) + user_id = values["user_id"] + expiry_ts = values["expiry_ts"] + used_ts = values["used_ts"] + auth_provider_id = values["auth_provider_id"] + auth_provider_session_id = values["auth_provider_session_id"] + + # Token was already used + if used_ts is not None: + raise LoginTokenReused() + + # Token expired + if ts > int(expiry_ts): + raise LoginTokenExpired() + + return LoginTokenLookupResult( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + async def consume_login_token(self, token: str) -> LoginTokenLookupResult: + """Lookup a login token and consume it. + + Args: + token: The login token. + + Returns: + The data stored with that token, including the `user_id`. Returns `None` if + the token does not exist or if it expired. + + Raises: + NotFound if the login token was not found in database + LoginTokenExpired if the login token expired + LoginTokenReused if the login token was already used + """ + return await self.db_pool.runInteraction( + "consume_login_token", + self._consume_login_token, + token, + self._clock.time_msec(), + ) + + async def invalidate_login_tokens_by_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> None: + """Invalidate login tokens with the given IdP session ID. + + Args: + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider + """ + await self.db_pool.simple_update( + table="login_tokens", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + updatevalues={"used_ts": self._clock.time_msec()}, + desc="invalidate_login_tokens_by_session_id", + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( @@ -2019,6 +2171,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): and hs.config.experimental.msc3866.require_approval_for_new_accounts ) + # Create a background job for removing expired login tokens + if hs.config.worker.run_background_tasks: + self._clock.looping_call( + self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS + ) + async def add_access_token_to_user( self, user_id: str, @@ -2617,6 +2775,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): approved, ) + @wrap_as_background_process("delete_expired_login_tokens") + async def _delete_expired_login_tokens(self) -> None: + """Remove login tokens with expiry dates that have passed.""" + + def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?" + txn.execute(sql, (ts,)) + + # We keep the expired tokens for an extra 5 minutes so we can measure how many + # times a token is being used after its expiry + now = self._clock.time_msec() + await self.db_pool.runInteraction( + "delete_expired_login_tokens", + _delete_expired_login_tokens_txn, + now - (5 * 60 * 1000), + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 154385b1e8..aea96e9d24 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -14,11 +14,13 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -28,19 +30,48 @@ from typing import ( import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ Contains enough information about a related event in order to properly filter @@ -51,11 +82,79 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str - topological_ordering: Optional[int] - stream_ordering: int class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -145,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore): txn.execute(sql, where_args + [limit + 1]) events = [] - for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: + topo_orderings: List[int] = [] + stream_orderings: List[int] = [] + for event_id, relation_type, sender, topo_ordering, stream_ordering in cast( + List[Tuple[str, str, str, int, int]], txn + ): # Do not include edits for redacted events as they leak event # content. if not is_redacted or relation_type != RelationTypes.REPLACE: - events.append( - _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) - ) + events.append(_RelatedEvent(event_id, sender)) + topo_orderings.append(topo_ordering) + stream_orderings.append(stream_ordering) # If there are more events, generate the next pagination key from the # last event returned. @@ -160,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore): # Instead of using the last row (which tells us there is more # data), use the last row to be returned. events = events[:limit] + topo_orderings = topo_orderings[:limit] + stream_orderings = stream_orderings[:limit] - topo = events[-1].topological_ordering - token = events[-1].stream_ordering + topo = topo_orderings[-1] + token = stream_orderings[-1] if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. @@ -195,6 +300,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. @@ -258,112 +399,196 @@ class RelationsWorkerStore(SQLBaseStore): ) return result is not None - @cached(tree=True) - async def get_aggregation_groups_for_event( - self, event_id: str, room_id: str, limit: int = 5 - ) -> List[JsonDict]: - """Get a list of annotations on the event, grouped by event type and + @cached() + async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" + ) + async def get_aggregation_groups_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[JsonDict]]]: + """Get a list of annotations on the given events, grouped by event type and aggregation key, sorted by count. This is used e.g. to get the what and how many reactions have happend on an event. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. Returns: - List of groups of annotations that match. Each row is a dict with - `type`, `key` and `count` fields. + A map of event IDs to a list of groups of annotations that match. + Each entry is a dict with `type`, `key` and `count` fields. """ + # The number of entries to return per event ID. + limit = 5 - args = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - limit, - ] + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.ANNOTATION) - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + sql = f""" + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE + {clause} + AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ - def _get_aggregation_groups_for_event_txn( + def _get_aggregation_groups_for_events_txn( txn: LoggingTransaction, - ) -> List[JsonDict]: + ) -> Mapping[str, List[JsonDict]]: txn.execute(sql, args) - return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] + result: Dict[str, List[JsonDict]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + event_results = result.setdefault(event_id, []) + + # Limit the number of results per event ID. + if len(event_results) == limit: + continue + + event_results.append({"type": type, "key": key, "count": count}) + + return result return await self.db_pool.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn ) async def get_aggregation_groups_for_users( - self, - event_id: str, - room_id: str, - limit: int, - users: FrozenSet[str] = frozenset(), - ) -> Dict[Tuple[str, str], int]: + self, event_ids: Collection[str], users: FrozenSet[str] + ) -> Dict[str, Dict[Tuple[str, str], int]]: """Fetch the partial aggregations for an event for specific users. This is used, in conjunction with get_aggregation_groups_for_event, to remove information from the results for ignored users. Args: - event_id: Fetch events that relate to this event ID. - room_id: The room the event belongs to. - limit: Only fetch the `limit` groups. + event_ids: Fetch events that relate to these event IDs. users: The users to fetch information for. Returns: - A map of (event type, aggregation key) to a count of users. + A map of event ID to a map of (event type, aggregation key) to a + count of users. """ if not users: return {} - args: List[Union[str, int]] = [ - event_id, - room_id, - RelationTypes.ANNOTATION, - ] + events_sql, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "sender", users + self.database_engine, "annotation.sender", users ) args.extend(users_args) + args.append(RelationTypes.ANNOTATION) sql = f""" - SELECT type, aggregation_key, COUNT(DISTINCT sender) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} - GROUP BY relation_type, type, aggregation_key - ORDER BY COUNT(*) DESC - LIMIT ? + SELECT + relates_to_id, + annotation.type, + aggregation_key, + COUNT(DISTINCT annotation.sender) + FROM events AS annotation + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = annotation.room_id + WHERE {events_sql} AND {users_sql} AND relation_type = ? + GROUP BY relates_to_id, annotation.type, aggregation_key + ORDER BY relates_to_id, COUNT(*) DESC """ def _get_aggregation_groups_for_users_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, str], int]: - txn.execute(sql, args + [limit]) + ) -> Dict[str, Dict[Tuple[str, str], int]]: + txn.execute(sql, args) - return {(row[0], row[1]): row[2] for row in txn} + result: Dict[str, Dict[Tuple[str, str], int]] = {} + for event_id, type, key, count in cast( + List[Tuple[str, str, str, int]], txn + ): + result.setdefault(event_id, {})[(type, key)] = count + + return result return await self.db_pool.runInteraction( "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) @cached() + async def get_references_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") + async def get_references_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to these event IDs. + + Returns: + A map of event IDs to a list of related event IDs (and their senders). + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REFERENCE) + + sql = f""" + SELECT relates_to_id, ref.event_id, ref.sender + FROM events AS ref + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = ref.room_id + WHERE + {clause} + AND relation_type = ? + ORDER BY ref.topological_ordering, ref.stream_ordering + """ + + def _get_references_for_events_txn( + txn: LoggingTransaction, + ) -> Mapping[str, List[_RelatedEvent]]: + txn.execute(sql, args) + + result: Dict[str, List[_RelatedEvent]] = {} + for relates_to_id, event_id, sender in cast( + List[Tuple[str, str, str]], txn + ): + result.setdefault(relates_to_id, []).append( + _RelatedEvent(event_id, sender) + ) + + return result + + return await self.db_pool.runInteraction( + "_get_references_for_events_txn", _get_references_for_events_txn + ) + + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() @@ -779,95 +1004,194 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? """ - def _get_event_relations( + def _get_threads_txn( txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) @cached() - async def get_thread_id(self, event_id: str) -> Optional[str]: + async def get_thread_id(self, event_id: str) -> str: """ Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. Returns: The event ID of the root event in the thread, if this event is part - of a thread. None, otherwise. + of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ - def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] - return None + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index e41c99027a..78906a5e1d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -1,5 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,8 +50,14 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import IdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + IdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -97,6 +103,12 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PartialStateResyncInfo: + joined_via: Optional[str] + servers_in_room: List[str] = attr.ib(factory=list) + + class RoomWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -108,6 +120,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self.config: HomeServerConfig = hs.config + self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator + + if isinstance(database.engine, PostgresEngine): + self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="un_partial_stated_room_stream", + instance_name=self._instance_name, + tables=[ + ("un_partial_stated_room_stream", "instance_name", "stream_id") + ], + sequence_name="un_partial_stated_room_stream_sequence", + # TODO(faster_joins, multiple writers) Support multiple writers. + writers=["master"], + ) + else: + self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator( + db_conn, "un_partial_stated_room_stream", "stream_id" + ) + async def store_room( self, room_id: str, @@ -906,7 +938,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): event_json = db_to_json(content_json) content = event_json["content"] content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") + info = content.get("info") + if isinstance(info, dict): + thumbnail_url = info.get("thumbnail_url") + else: + thumbnail_url = None for url in (content_url, thumbnail_url): if not url: @@ -1160,17 +1196,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="get_partial_state_servers_at_join", ) - async def get_partial_state_rooms_and_servers( + async def get_partial_state_room_resync_info( self, - ) -> Mapping[str, Collection[str]]: - """Get all rooms containing events with partial state, and the servers known - to be in the room. + ) -> Mapping[str, PartialStateResyncInfo]: + """Get all rooms containing events with partial state, and the information + needed to restart a "resync" of those rooms. Returns: A dictionary of rooms with partial state, with room IDs as keys and lists of servers in rooms as values. """ - room_servers: Dict[str, List[str]] = {} + room_servers: Dict[str, PartialStateResyncInfo] = {} + + rows = await self.db_pool.simple_select_list( + table="partial_state_rooms", + keyvalues={}, + retcols=("room_id", "joined_via"), + desc="get_server_which_served_partial_join", + ) + + for row in rows: + room_id = row["room_id"] + joined_via = row["joined_via"] + room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) rows = await self.db_pool.simple_select_list( "partial_state_rooms_servers", @@ -1182,74 +1230,18 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): for row in rows: room_id = row["room_id"] server_name = row["server_name"] - room_servers.setdefault(room_id, []).append(server_name) + entry = room_servers.get(room_id) + if entry is None: + # There is a foreign key constraint which enforces that every room_id in + # partial_state_rooms_servers appears in partial_state_rooms. So we + # expect `entry` to be non-null. (This reasoning fails if we've + # partial-joined between the two SELECTs, but this is unlikely to happen + # in practice.) + continue + entry.servers_in_room.append(server_name) return room_servers - async def clear_partial_state_room(self, room_id: str) -> bool: - """Clears the partial state flag for a room. - - Args: - room_id: The room whose partial state flag is to be cleared. - - Returns: - `True` if the partial state flag has been cleared successfully. - - `False` if the partial state flag could not be cleared because the room - still contains events with partial state. - """ - try: - await self.db_pool.runInteraction( - "clear_partial_state_room", self._clear_partial_state_room_txn, room_id - ) - return True - except self.db_pool.engine.module.IntegrityError as e: - # Assume that any `IntegrityError`s are due to partial state events. - logger.info( - "Exception while clearing lazy partial-state-room %s, retrying: %s", - room_id, - e, - ) - return False - - def _clear_partial_state_room_txn( - self, txn: LoggingTransaction, room_id: str - ) -> None: - DatabasePool.simple_delete_txn( - txn, - table="partial_state_rooms_servers", - keyvalues={"room_id": room_id}, - ) - DatabasePool.simple_delete_one_txn( - txn, - table="partial_state_rooms", - keyvalues={"room_id": room_id}, - ) - self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) - self._invalidate_cache_and_stream( - txn, self.get_partial_state_servers_at_join, (room_id,) - ) - - # We now delete anything from `device_lists_remote_pending` with a - # stream ID less than the minimum - # `partial_state_rooms.device_lists_stream_id`, as we no longer need them. - device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn( - txn, - table="partial_state_rooms", - keyvalues={}, - retcol="MIN(device_lists_stream_id)", - allow_none=True, - ) - if device_lists_stream_id is None: - # There are no rooms being currently partially joined, so we delete everything. - txn.execute("DELETE FROM device_lists_remote_pending") - else: - sql = """ - DELETE FROM device_lists_remote_pending - WHERE stream_id <= ? - """ - txn.execute(sql, (device_lists_stream_id,)) - @cached() async def is_partial_state_room(self, room_id: str) -> bool: """Checks if this room has partial state. @@ -1285,6 +1277,66 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) return result["join_event_id"], result["device_lists_stream_id"] + def get_un_partial_stated_rooms_token(self) -> int: + # TODO(faster_joins, multiple writers): This is inappropriate if there + # are multiple writers because workers that don't write often will + # hold all readers up. + # (See `MultiWriterIdGenerator.get_persisted_upto_position` for an + # explanation.) + return self._un_partial_stated_rooms_stream_id_gen.get_current_token() + + async def get_un_partial_stated_rooms_from_stream( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: + """Get updates for caches replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_un_partial_stated_rooms_from_stream_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: + sql = """ + SELECT stream_id, room_id + FROM un_partial_stated_room_stream + WHERE ? < stream_id AND stream_id <= ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, instance_name, limit)) + updates = [(row[0], (row[1],)) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_un_partial_stated_rooms_from_stream", + get_un_partial_stated_rooms_from_stream_txn, + ) + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1776,6 +1828,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") + self._instance_name = hs.get_instance_name() + async def upsert_room_on_join( self, room_id: str, room_version: RoomVersion, state_events: List[EventBase] ) -> None: @@ -1817,9 +1871,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "creator": room_creator, "has_auth_chain_index": has_auth_chain_index, }, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. - lock=False, ) async def store_partial_state_room( @@ -1827,6 +1878,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: """Mark the given room as containing events with partial state. @@ -1842,6 +1894,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): servers: other servers known to be in the room device_lists_stream_id: the device_lists stream ID at the time when we first joined the room. + joined_via: the server name we requested a partial join from. """ await self.db_pool.runInteraction( "store_partial_state_room", @@ -1849,6 +1902,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id, servers, device_lists_stream_id, + joined_via, ) def _store_partial_state_room_txn( @@ -1857,6 +1911,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: DatabasePool.simple_insert_txn( txn, @@ -1866,6 +1921,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "device_lists_stream_id": device_lists_stream_id, # To be updated later once the join event is persisted. "join_event_id": None, + "joined_via": joined_via, }, ) DatabasePool.simple_insert_many_txn( @@ -1935,9 +1991,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "creator": "", "has_auth_chain_index": has_auth_chain_index, }, - # rooms has a unique constraint on room_id, so no need to lock when doing an - # emulated upsert. - lock=False, ) async def set_room_is_public(self, room_id: str, is_public: bool) -> None: @@ -2026,7 +2079,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): Args: report_id: ID of reported event in database Returns: - event_report: json list of information from event report + JSON dict of information from an event report or None if the + report does not exist. """ def _get_event_report_txn( @@ -2099,8 +2153,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None Returns: - event_reports: json list of event reports - count: total number of event reports matching the filter criteria + Tuple of: + json list of event reports + total number of event reports matching the filter criteria """ def _get_event_reports_paginate_txn( @@ -2239,3 +2294,84 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self.is_room_blocked, (room_id,), ) + + async def clear_partial_state_room(self, room_id: str) -> bool: + """Clears the partial state flag for a room. + + Args: + room_id: The room whose partial state flag is to be cleared. + + Returns: + `True` if the partial state flag has been cleared successfully. + + `False` if the partial state flag could not be cleared because the room + still contains events with partial state. + """ + try: + async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id: + await self.db_pool.runInteraction( + "clear_partial_state_room", + self._clear_partial_state_room_txn, + room_id, + un_partial_state_room_stream_id, + ) + return True + except self.db_pool.engine.module.IntegrityError as e: + # Assume that any `IntegrityError`s are due to partial state events. + logger.info( + "Exception while clearing lazy partial-state-room %s, retrying: %s", + room_id, + e, + ) + return False + + def _clear_partial_state_room_txn( + self, + txn: LoggingTransaction, + room_id: str, + un_partial_state_room_stream_id: int, + ) -> None: + DatabasePool.simple_delete_txn( + txn, + table="partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + ) + DatabasePool.simple_delete_one_txn( + txn, + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + ) + self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + self._invalidate_cache_and_stream( + txn, self.get_partial_state_servers_at_join, (room_id,) + ) + + DatabasePool.simple_insert_txn( + txn, + "un_partial_stated_room_stream", + { + "stream_id": un_partial_state_room_stream_id, + "instance_name": self._instance_name, + "room_id": room_id, + }, + ) + + # We now delete anything from `device_lists_remote_pending` with a + # stream ID less than the minimum + # `partial_state_rooms.device_lists_stream_id`, as we no longer need them. + device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn( + txn, + table="partial_state_rooms", + keyvalues={}, + retcol="MIN(device_lists_stream_id)", + allow_none=True, + ) + if device_lists_stream_id is None: + # There are no rooms being currently partially joined, so we delete everything. + txn.execute("DELETE FROM device_lists_remote_pending") + else: + sql = """ + DELETE FROM device_lists_remote_pending + WHERE stream_id <= ? + """ + txn.execute(sql, (device_lists_stream_id,)) diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index 39e80f6f5b..131f357d04 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py
@@ -44,6 +44,4 @@ class RoomBatchStore(SQLBaseStore): table="event_to_state_groups", keyvalues={"event_id": event_id}, values={"state_group": state_group_id, "event_id": event_id}, - # Unique constraint on event_id so we don't have to lock - lock=False, ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2337289d88..f02c1d7ea7 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -152,6 +152,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): the forward extremities of those rooms will exclude most members. We may also calculate room state incorrectly for such rooms and believe that a member is or is not in the room when the opposite is true. + + Note: If you only care about users in the room local to the homeserver, use + `get_local_users_in_room(...)` instead which will be more performant. """ return await self.db_pool.simple_select_onecol( table="current_state_events", @@ -666,7 +669,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cached_method_name="get_rooms_for_user", list_name="user_ids", ) - async def get_rooms_for_users( + async def _get_rooms_for_users( self, user_ids: Collection[str] ) -> Dict[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. @@ -697,6 +700,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {key: frozenset(rooms) for key, rooms in user_rooms.items()} + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched wrapper around `_get_rooms_for_users`, to prevent locking + other calls to `get_rooms_for_user` for large user lists. + """ + all_user_rooms: Dict[str, FrozenSet[str]] = {} + + # 250 users is pretty arbitrary but the data can be quite large if users + # are in many rooms. + for batch_user_ids in batch_iter(user_ids, 250): + all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids)) + + return all_user_rooms + @cached(max_entries=10000) async def does_pair_of_users_share_a_room( self, user_id: str, other_user_id: str @@ -727,7 +745,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # user and the set of other users, and then checking if there is any # overlap. sql = f""" - SELECT b.state_key + SELECT DISTINCT b.state_key FROM ( SELECT room_id FROM current_state_events WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ? @@ -736,7 +754,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): SELECT room_id, state_key FROM current_state_events WHERE type = 'm.room.member' AND membership = 'join' AND {clause} ) AS b using (room_id) - LIMIT 1 """ txn.execute(sql, (user_id, *args)) @@ -1500,6 +1517,36 @@ class RoomMemberStore( await self.db_pool.runInteraction("forget_membership", f) +def extract_heroes_from_room_summary( + details: Mapping[str, MemberSummary], me: str +) -> List[str]: + """Determine the users that represent a room, from the perspective of the `me` user. + + The rules which say which users we select are specified in the "Room Summary" + section of + https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync + + Returns a list (possibly empty) of heroes' mxids. + """ + empty_ms = MemberSummary([], 0) + + joined_user_ids = [ + r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me + ] + invited_user_ids = [ + r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me + ] + gone_user_ids = [ + r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me + ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] + + # FIXME: order by stream ordering rather than as returned by SQL + if joined_user_ids or invited_user_ids: + return sorted(joined_user_ids + invited_user_ids)[0:5] + else: + return sorted(gone_user_ids)[0:5] + + @attr.s(slots=True, auto_attribs=True) class _JoinedHostsCache: """The cached data used by the `_get_joined_hosts_cache`.""" diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1b79acf955..3fe433f66c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -11,10 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple +from collections import deque +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr @@ -27,7 +39,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -68,11 +80,11 @@ class SearchWorkerStore(SQLBaseStore): if not self.hs.config.server.enable_search: return if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) + sql = """ + INSERT INTO event_search + (event_id, room_id, key, vector, stream_ordering, origin_server_ts) + VALUES (?,?,?,to_tsvector('english', ?),?,?) + """ args1 = ( ( @@ -89,20 +101,20 @@ class SearchWorkerStore(SQLBaseStore): txn.execute_batch(sql, args1) elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args2 = ( - ( - entry.event_id, - entry.room_id, - entry.key, - _clean_value_for_search(entry.value), - ) - for entry in entries + self.db_pool.simple_insert_many_txn( + txn, + table="event_search", + keys=("event_id", "room_id", "key", "value"), + values=( + ( + entry.event_id, + entry.room_id, + entry.key, + _clean_value_for_search(entry.value), + ) + for entry in entries + ), ) - txn.execute_batch(sql, args2) else: # This should be unreachable. @@ -150,15 +162,17 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): TYPES = ["m.room.name", "m.room.message", "m.room.topic"] def reindex_search_txn(txn: LoggingTransaction) -> int: - sql = ( - "SELECT stream_ordering, event_id, room_id, type, json, " - " origin_server_ts FROM events" - " JOIN event_json USING (room_id, event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " AND (%s)" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) + sql = """ + SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts + FROM events + JOIN event_json USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND (%s) + ORDER BY stream_ordering DESC + LIMIT ? + """ % ( + " OR ".join("type = '%s'" % (t,) for t in TYPES), + ) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) @@ -272,8 +286,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): try: c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx" - " ON event_search USING GIN (vector)" + """ + CREATE INDEX CONCURRENTLY event_search_fts_idx + ON event_search USING GIN (vector) + """ ) except psycopg2.ProgrammingError as e: logger.warning( @@ -311,12 +327,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # We create with NULLS FIRST so that when we search *backwards* # we get the ones with non null origin_server_ts *first* c.execute( - "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" - "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_room_order + ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) c.execute( - "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" - "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_order + ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) conn.set_session(autocommit=False) @@ -333,14 +353,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]: - sql = ( - "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," - " origin_server_ts = e.origin_server_ts" - " FROM events AS e" - " WHERE e.event_id = es.event_id" - " AND ? <= e.stream_ordering AND e.stream_ordering < ?" - " RETURNING es.stream_ordering" - ) + sql = """ + UPDATE event_search AS es + SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts + FROM events AS e + WHERE e.event_id = es.event_id + AND ? <= e.stream_ordering AND e.stream_ordering < ? + RETURNING es.stream_ordering + """ min_stream_id = max_stream_id - batch_size txn.execute(sql, (min_stream_id, max_stream_id)) @@ -421,8 +441,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -444,32 +462,35 @@ class SearchStore(SearchBackgroundUpdateStore): count_clauses = clauses if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," - " room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) + search_query = search_term + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank, + room_id, event_id + FROM event_search + WHERE vector @@ websearch_to_tsquery('english', ?) + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ websearch_to_tsquery('english', ?) + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" - " FROM event_search" - " WHERE value MATCH ?" - ) + search_query = _parse_query_for_sqlite(search_term) + + sql = """ + SELECT rank(matchinfo(event_search)) as rank, room_id, event_id + FROM event_search + WHERE value MATCH ? + """ args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ?" - ) - count_args = [search_term] + count_args + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? + """ + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -510,7 +531,6 @@ class SearchStore(SearchBackgroundUpdateStore): ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - return { "results": [ {"event": event_map[r["event_id"]], "rank": r["rank"]} @@ -542,9 +562,6 @@ class SearchStore(SearchBackgroundUpdateStore): Each match as a dictionary. """ clauses = [] - - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -576,26 +593,29 @@ class SearchStore(SearchBackgroundUpdateStore): raise SynapseError(400, "Invalid pagination token") clauses.append( - "(origin_server_ts < ?" - " OR (origin_server_ts = ? AND stream_ordering < ?))" + """ + (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?)) + """ ) args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," - " origin_server_ts, stream_ordering, room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " - ) + search_query = search_term + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank, + origin_server_ts, stream_ordering, room_id, event_id + FROM event_search + WHERE vector @@ websearch_to_tsquery('english', ?) AND + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ websearch_to_tsquery('english', ?) AND + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # @@ -604,23 +624,25 @@ class SearchStore(SearchBackgroundUpdateStore): # in the events table to get the topological ordering. We need # to use the indexes in this order because sqlite refuses to # MATCH unless it uses the full text search index - sql = ( - "SELECT rank(matchinfo) as rank, room_id, event_id," - " origin_server_ts, stream_ordering" - " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" - " FROM event_search" - " WHERE value MATCH ?" - " )" - " CROSS JOIN events USING (event_id)" - " WHERE " + sql = """ + SELECT + rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering + FROM ( + SELECT key, event_id, matchinfo(event_search) as matchinfo + FROM event_search + WHERE value MATCH ? ) + CROSS JOIN events USING (event_id) + WHERE + """ + search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ? AND " - ) - count_args = [search_term] + count_args + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? AND + """ + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -631,10 +653,10 @@ class SearchStore(SearchBackgroundUpdateStore): # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. if isinstance(self.database_engine, PostgresEngine): - sql += ( - " ORDER BY origin_server_ts DESC NULLS LAST," - " stream_ordering DESC NULLS LAST LIMIT ?" - ) + sql += """ + ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST + LIMIT ? + """ elif isinstance(self.database_engine, Sqlite3Engine): sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" else: @@ -729,13 +751,16 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( - _to_postgres_options( - { - "StartSel": start_sel, - "StopSel": stop_sel, - "MaxFragments": "50", - } + query = ( + "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)" + % ( + _to_postgres_options( + { + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + } + ) ) ) txn.execute(query, (value, search_query)) @@ -760,20 +785,127 @@ def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. +@dataclass +class Phrase: + phrase: List[str] + + +class SearchToken(enum.Enum): + Not = enum.auto() + Or = enum.auto() + And = enum.auto() + + +Token = Union[str, Phrase, SearchToken] +TokenList = List[Token] + + +def _is_stop_word(word: str) -> bool: + # TODO Pull these out of the dictionary: + # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop + return word in {"the", "a", "you", "me", "and", "but"} + + +def _tokenize_query(query: str) -> TokenList: """ + Convert the user-supplied `query` into a TokenList, which can be translated into + some DB-specific syntax. + + The following constructs are supported: + + - phrase queries using "double quotes" + - case-insensitive `or` and `and` operators + - negation of a keyword via unary `-` + - unary hyphen to denote NOT e.g. 'include -exclude' - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + The following differs from websearch_to_tsquery: - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) - elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") + - Stop words are not removed. + - Unclosed phrases are treated differently. + + """ + tokens: TokenList = [] + + # Find phrases. + in_phrase = False + parts = deque(query.split('"')) + for i, part in enumerate(parts): + # The contents inside double quotes is treated as a phrase. + in_phrase = bool(i % 2) + + # Pull out the individual words, discarding any non-word characters. + words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) + + # Phrases have simplified handling of words. + if in_phrase: + # Skip stop words. + phrase = [word for word in words if not _is_stop_word(word)] + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the phrase. + tokens.append(Phrase(phrase)) + continue + + # Otherwise, not in a phrase. + while words: + word = words.popleft() + + if word.startswith("-"): + tokens.append(SearchToken.Not) + + # If there's more word, put it back to be processed again. + word = word[1:] + if word: + words.appendleft(word) + elif word.lower() == "or": + tokens.append(SearchToken.Or) + else: + # Skip stop words. + if _is_stop_word(word): + continue + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the search term. + tokens.append(word) + + return tokens + + +def _tokens_to_sqlite_match_query(tokens: TokenList) -> str: + """ + Convert the list of tokens to a string suitable for passing to sqlite's MATCH. + Assume sqlite was compiled with enhanced query syntax. + + Ref: https://www.sqlite.org/fts3.html#full_text_index_queries + """ + match_query = [] + for token in tokens: + if isinstance(token, str): + match_query.append(token) + elif isinstance(token, Phrase): + match_query.append('"' + " ".join(token.phrase) + '"') + elif token == SearchToken.Not: + # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search + # term has already been added before this. + match_query.append(" NOT ") + elif token == SearchToken.Or: + match_query.append(" OR ") + elif token == SearchToken.And: + match_query.append(" AND ") + else: + raise ValueError(f"unknown token {token}") + + return "".join(match_query) + + +def _parse_query_for_sqlite(search_term: str) -> str: + """Takes a plain unicode string from the user and converts it into a form + that can be passed to sqllite's matchinfo(). + """ + return _tokens_to_sqlite_match_query(_tokenize_query(search_term)) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index af7bebee80..c801a93b5b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -33,8 +33,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 530f04e149..cc27ec3804 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -397,6 +415,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() def get_room_max_stream_ordering(self) -> int: """Get the stream_ordering of regular events that we have committed up to @@ -1024,28 +1043,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False - ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: + async def get_all_new_event_ids_stream( + self, + from_id: int, + current_id: int, + limit: int, + ) -> Tuple[int, Dict[str, Optional[int]]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all event ids with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return - get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events, event_to_received_ts), where `next_id` + A tuple of (next_id, event_to_received_ts), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). The `event_to_received_ts` is - a dictionary mapping event ID to the event `received_ts`. + a dictionary mapping event ID to the event `received_ts`, sorted by ascending + stream_ordering. """ - def get_all_new_events_stream_txn( + def get_all_new_event_ids_stream_txn( txn: LoggingTransaction, ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( @@ -1070,15 +1092,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, event_to_received_ts upper_bound, event_to_received_ts = await self.db_pool.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = await self.get_events_as_list( - event_to_received_ts.keys(), - get_prev_content=get_prev_content, + "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn ) - return upper_bound, events, event_to_received_ts + return upper_bound, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: @@ -1202,8 +1219,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. @@ -1282,8 +1297,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1298,6 +1313,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ddb25b5cea..14ef5b040d 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,14 @@ from typing import ( cast, ) +try: + # Figure out if ICU support is available for searching users. + import icu + + USE_ICU = True +except ModuleNotFoundError: + USE_ICU = False + from typing_extensions import TypedDict from synapse.api.errors import StoreError @@ -185,9 +193,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): - who should be in the user_directory. Args: - progress (dict) - batch_size (int): Maximum number of state events to process - per cycle. + progress + batch_size: Maximum number of state events to process per cycle. Returns: number of events processed. @@ -482,7 +489,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): table="user_directory", keyvalues={"user_id": user_id}, values={"display_name": display_name, "avatar_url": avatar_url}, - lock=False, # We're only inserter ) if isinstance(self.database_engine, PostgresEngine): @@ -512,7 +518,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): table="user_directory_search", keyvalues={"user_id": user_id}, values={"value": value}, - lock=False, # We're only inserter ) else: # This should be unreachable. @@ -708,10 +713,10 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns the rooms that a user is in. Args: - user_id(str): Must be a local user + user_id: Must be a local user Returns: - list: user_id + List of room IDs """ rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", @@ -889,7 +894,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): limited = len(results) > limit - return {"limited": limited, "results": results} + return {"limited": limited, "results": results[0:limit]} def _parse_query_sqlite(search_term: str) -> str: @@ -903,7 +908,7 @@ def _parse_query_sqlite(search_term: str) -> str: """ # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + results = _parse_words(search_term) return " & ".join("(%s* OR %s)" % (result, result) for result in results) @@ -913,12 +918,63 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + results = _parse_words(search_term) both = " & ".join("(%s:* | %s)" % (result, result) for result in results) exact = " & ".join("%s" % (result,) for result in results) prefix = " & ".join("%s:*" % (result,) for result in results) return both, exact, prefix + + +def _parse_words(search_term: str) -> List[str]: + """Split the provided search string into a list of its words. + + If support for ICU (International Components for Unicode) is available, use it. + Otherwise, fall back to using a regex to detect word boundaries. This latter + solution works well enough for most latin-based languages, but doesn't work as well + with other languages. + + Args: + search_term: The search string. + + Returns: + A list of the words in the search string. + """ + if USE_ICU: + return _parse_words_with_icu(search_term) + + return re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + +def _parse_words_with_icu(search_term: str) -> List[str]: + """Break down the provided search string into its individual words using ICU + (International Components for Unicode). + + Args: + search_term: The search string. + + Returns: + A list of the words in the search string. + """ + results = [] + breaker = icu.BreakIterator.createWordInstance(icu.Locale.getDefault()) + breaker.setText(search_term) + i = 0 + while True: + j = breaker.nextBoundary() + if j < 0: + break + + result = search_term[i:j] + + # libicu considers spaces and punctuation between words as words, but we don't + # want to include those in results as they would result in syntax errors in SQL + # queries (e.g. "foo bar" would result in the search query including "foo & & + # bar"). + if len(re.findall(r"([\w\-]+)", result, re.UNICODE)): + results.append(result) + + i = j + + return results