From 5442891cbca67d3af27c448791589e0b9abeb7f8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Aug 2022 12:22:17 +0100 Subject: Make push rules use proper structures. (#13522) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This improves load times for push rules: | Version | Time per user | Time for 1k users | | -------------------- | ------------- | ----------------- | | Before | 138 µs | 138ms | | Now (with custom) | 2.11 µs | 2.11ms | | Now (without custom) | 49.7 ns | 0.05 ms | This therefore has a large impact on send times for rooms with large numbers of local users in the room. --- synapse/push/bulk_push_rule_evaluator.py | 37 ++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 14 deletions(-) (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 713dcf6950..ccd512be54 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -15,7 +15,18 @@ import itertools import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from prometheus_client import Counter @@ -30,6 +41,7 @@ from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state +from .baserules import FilteredPushRules, PushRule from .push_rule_evaluator import PushRuleEvaluatorForEvent if TYPE_CHECKING: @@ -112,7 +124,7 @@ class BulkPushRuleEvaluator: async def _get_rules_for_event( self, event: EventBase, - ) -> Dict[str, List[Dict[str, Any]]]: + ) -> Dict[str, FilteredPushRules]: """Get the push rules for all users who may need to be notified about the event. @@ -186,7 +198,7 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level async def _get_mutual_relations( - self, event: EventBase, rules: Iterable[Dict[str, Any]] + self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]] ) -> Dict[str, Set[Tuple[str, str]]]: """ Fetch event metadata for events which related to the same event as the given event. @@ -216,12 +228,11 @@ class BulkPushRuleEvaluator: # Pre-filter to figure out which relation types are interesting. rel_types = set() - for rule in rules: - # Skip disabled rules. - if "enabled" in rule and not rule["enabled"]: + for rule, enabled in rules: + if not enabled: continue - for condition in rule["conditions"]: + for condition in rule.conditions: if condition["kind"] != "org.matrix.msc3772.relation_match": continue @@ -254,7 +265,7 @@ class BulkPushRuleEvaluator: count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event) - actions_by_user: Dict[str, List[Union[dict, str]]] = {} + actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {} room_member_count = await self.store.get_number_joined_users_in_room( event.room_id @@ -317,15 +328,13 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - for rule in rules: - if "enabled" in rule and not rule["enabled"]: + for rule, enabled in rules: + if not enabled: continue - matches = evaluator.check_conditions( - rule["conditions"], uid, display_name - ) + matches = evaluator.check_conditions(rule.conditions, uid, display_name) if matches: - actions = [x for x in rule["actions"] if x != "dont_notify"] + actions = [x for x in rule.actions if x != "dont_notify"] if actions and "notify" in actions: # Push rules say we should notify the user of this event actions_by_user[uid] = actions -- cgit 1.5.1 From 390b7ce946173b41a61f427ef25a9a1d0371ad0b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 1 Sep 2022 12:52:03 -0400 Subject: Disable calculating unread counts unless the config flag is enabled. (#13694) This avoids doing work that will never be used (since the resulting unread counts will never be sent in a /sync response). The negative of doing this is that unread counts will be incorrect when the feature is initially enabled. --- changelog.d/13694.bugfix | 1 + synapse/config/experimental.py | 3 +++ synapse/push/bulk_push_rule_evaluator.py | 7 +++++- tests/storage/test_event_push_actions.py | 42 +++++++++++++++----------------- 4 files changed, 30 insertions(+), 23 deletions(-) create mode 100644 changelog.d/13694.bugfix (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13694.bugfix b/changelog.d/13694.bugfix new file mode 100644 index 0000000000..48b9bb5f0a --- /dev/null +++ b/changelog.d/13694.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.20.0 that would cause the unstable unread counts from [MSC2654](https://github.com/matrix-org/matrix-spec-proposals/pull/2654) to be calculated even if the feature is disabled. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 260db49cad..702b81e636 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -71,6 +71,9 @@ class ExperimentalConfig(Config): self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) # MSC2654: Unread counts + # + # Note that enabling this will result in an incorrect unread count for + # previously calculated push actions. self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False) # MSC2815 (allow room moderators to view redacted event content) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index ccd512be54..d1caf8a0f7 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -262,7 +262,12 @@ class BulkPushRuleEvaluator: # This can happen due to out of band memberships return - count_as_unread = _should_count_as_unread(event, context) + # Disable counting as unread unless the experimental configuration is + # enabled, as it can cause additional (unwanted) rows to be added to the + # event_push_actions table. + count_as_unread = False + if self.hs.config.experimental.msc2654_enabled: + count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event) actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {} diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 62fd4aeb2f..fc43d7edd1 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -67,9 +67,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str - def _assert_counts( - noitf_count: int, unread_count: int, highlight_count: int - ) -> None: + def _assert_counts(noitf_count: int, highlight_count: int) -> None: counts = self.get_success( self.store.db_pool.runInteraction( "get-unread-counts", @@ -82,7 +80,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): counts, NotifCounts( notify_count=noitf_count, - unread_count=unread_count, + unread_count=0, highlight_count=highlight_count, ), ) @@ -112,27 +110,27 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _create_event() - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _rotate() - _assert_counts(1, 1, 0) + _assert_counts(1, 0) event_id = _create_event() - _assert_counts(2, 2, 0) + _assert_counts(2, 0) _rotate() - _assert_counts(2, 2, 0) + _assert_counts(2, 0) _create_event() _mark_read(event_id) - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _create_event() _rotate() - _assert_counts(1, 1, 0) + _assert_counts(1, 0) # Delete old event push actions, this should not affect the (summarised) count. # @@ -151,35 +149,35 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) self.assertEqual(result, []) - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) event_id = _create_event(True) - _assert_counts(1, 1, 1) + _assert_counts(1, 1) _rotate() - _assert_counts(1, 1, 1) + _assert_counts(1, 1) # Check that adding another notification and rotating after highlight # works. _create_event() _rotate() - _assert_counts(2, 2, 1) + _assert_counts(2, 1) # Check that sending read receipts at different points results in the # right counts. _mark_read(event_id) - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _create_event(True) - _assert_counts(1, 1, 1) + _assert_counts(1, 1) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _rotate() - _assert_counts(0, 0, 0) + _assert_counts(0, 0) def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: -- cgit 1.5.1 From 666ae877292d4747b9441105e3df8558f7a335c0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Sep 2022 13:11:16 -0400 Subject: Update event push action and receipt tables to support threads. (#13753) Adds a `thread_id` column to the `event_push_actions`, `event_push_actions_staging`, and `event_push_summary` tables. This will notifications to be segmented by the thread in a future pull request. The `thread_id` column stores the root event ID or the special value `"main"`. The `thread_id` column for `event_push_actions` and `event_push_summary` is backfilled with `"main"` for all existing rows. New entries into `event_push_actions` and `event_push_actions_staging` will get the proper thread ID. `receipts_linearized` and `receipts_graph` also gain a `thread_id` column, which is similar, except `NULL` is a special value meaning the receipt is "unthreaded". See MSC3771 and MSC3773 for where this data will be useful. --- changelog.d/13753.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 29 ++--- .../storage/databases/main/event_push_actions.py | 121 ++++++++++++++++++++- synapse/storage/databases/main/events.py | 4 +- synapse/storage/databases/main/receipts.py | 20 ++++ synapse/storage/schema/__init__.py | 6 +- .../main/delta/72/06thread_notifications.sql | 30 +++++ .../main/delta/72/07thread_receipts.sql.postgres | 30 +++++ .../main/delta/72/07thread_receipts.sql.sqlite | 70 ++++++++++++ .../schema/main/delta/72/08thread_receipts.sql | 20 ++++ tests/replication/slave/storage/test_events.py | 1 + 11 files changed, 312 insertions(+), 20 deletions(-) create mode 100644 changelog.d/13753.misc create mode 100644 synapse/storage/schema/main/delta/72/06thread_notifications.sql create mode 100644 synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres create mode 100644 synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite create mode 100644 synapse/storage/schema/main/delta/72/08thread_receipts.sql (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13753.misc b/changelog.d/13753.misc new file mode 100644 index 0000000000..63de2eb9f9 --- /dev/null +++ b/changelog.d/13753.misc @@ -0,0 +1 @@ +Prepatory work for storing thread IDs for notifications and receipts. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d1caf8a0f7..3846fbc5f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -198,7 +198,7 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level async def _get_mutual_relations( - self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]] + self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] ) -> Dict[str, Set[Tuple[str, str]]]: """ Fetch event metadata for events which related to the same event as the given event. @@ -206,7 +206,7 @@ class BulkPushRuleEvaluator: If the given event has no relation information, returns an empty dictionary. Args: - event_id: The event ID which is targeted by relations. + parent_id: The event ID which is targeted by relations. rules: The push rules which will be processed for this event. Returns: @@ -220,12 +220,6 @@ class BulkPushRuleEvaluator: if not self._relations_match_enabled: return {} - # If the event does not have a relation, then cannot have any mutual - # relations. - relation = relation_from_event(event) - if not relation: - return {} - # Pre-filter to figure out which relation types are interesting. rel_types = set() for rule, enabled in rules: @@ -246,9 +240,7 @@ class BulkPushRuleEvaluator: return {} # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations( - relation.parent_id, rel_types - ) + return await self.store.get_mutual_event_relations(parent_id, rel_types) @measure_func("action_for_event_by_user") async def action_for_event_by_user( @@ -281,9 +273,17 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) - relations = await self._get_mutual_relations( - event, itertools.chain(*rules_by_user.values()) - ) + relation = relation_from_event(event) + # If the event does not have a relation, then cannot have any mutual + # relations or thread ID. + relations = {} + thread_id = "main" + if relation: + relations = await self._get_mutual_relations( + relation.parent_id, itertools.chain(*rules_by_user.values()) + ) + if relation.rel_type == RelationTypes.THREAD: + thread_id = relation.parent_id evaluator = PushRuleEvaluatorForEvent( event, @@ -352,6 +352,7 @@ class BulkPushRuleEvaluator: event.event_id, actions_by_user, count_as_unread, + thread_id, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3a3fb8c507..6b8668d2dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -98,6 +98,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -232,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas replaces_index="event_push_summary_user_rm", ) + self.db_pool.updates.register_background_index_update( + "event_push_summary_unique_index2", + index_name="event_push_summary_unique_index2", + table="event_push_summary", + columns=["user_id", "room_id", "thread_id"], + unique=True, + ) + + self.db_pool.updates.register_background_update_handler( + "event_push_backfill_thread_id", + self._background_backfill_thread_id, + ) + + async def _background_backfill_thread_id( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Fill in the thread_id field for event_push_actions and event_push_summary. + + This is preparatory so that it can be made non-nullable in the future. + + Because all current (null) data is done in an unthreaded manner this + simply assumes it is on the "main" timeline. Since event_push_actions + are periodically cleared it is not possible to correctly re-calculate + the thread_id. + """ + event_push_actions_done = progress.get("event_push_actions_done", False) + + def add_thread_id_txn( + txn: LoggingTransaction, table_name: str, start_stream_ordering: int + ) -> int: + sql = f""" + SELECT stream_ordering + FROM {table_name} + WHERE + thread_id IS NULL + AND stream_ordering > ? + ORDER BY stream_ordering + LIMIT ? + """ + txn.execute(sql, (start_stream_ordering, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + progress[f"{table_name}_done"] = True + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + return 0 + + # Update the thread ID for any of those rows. + max_stream_ordering = rows[-1][0] + + sql = f""" + UPDATE {table_name} + SET thread_id = 'main' + WHERE stream_ordering <= ? AND thread_id IS NULL + """ + txn.execute(sql, (max_stream_ordering,)) + + # Update progress. + processed_rows = txn.rowcount + progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + + return processed_rows + + # First update the event_push_actions table, then the event_push_summary table. + # + # Note that the event_push_actions_staging table is ignored since it is + # assumed that items in that table will only exist for a short period of + # time. + if not event_push_actions_done: + result = await self.db_pool.runInteraction( + "event_push_backfill_thread_id", + add_thread_id_txn, + "event_push_actions", + progress.get("max_event_push_actions_stream_ordering", 0), + ) + else: + result = await self.db_pool.runInteraction( + "event_push_backfill_thread_id", + add_thread_id_txn, + "event_push_summary", + progress.get("max_event_push_summary_stream_ordering", 0), + ) + + # Only done after the event_push_summary table is done. + if not result: + await self.db_pool.updates._end_background_update( + "event_push_backfill_thread_id" + ) + + return result + @cached(tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( self, @@ -670,6 +769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas event_id: str, user_id_actions: Dict[str, Collection[Union[Mapping, str]]], count_as_unread: bool, + thread_id: str, ) -> None: """Add the push actions for the event to the push action staging area. @@ -678,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id_actions: A mapping of user_id to list of push actions, where an action can either be a string or dict. count_as_unread: Whether this event should increment unread counts. + thread_id: The thread this event is parent of, if applicable. """ if not user_id_actions: return @@ -686,7 +787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( user_id: str, actions: Collection[Union[Mapping, str]] - ) -> Tuple[str, str, str, int, int, int]: + ) -> Tuple[str, str, str, int, int, int, str]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 return ( @@ -696,11 +797,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas notif, # notif column is_highlight, # highlight column int(count_as_unread), # unread column + thread_id, # thread_id column ) await self.db_pool.simple_insert_many( "event_push_actions_staging", - keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"), + keys=( + "event_id", + "user_id", + "actions", + "notif", + "highlight", + "unread", + "thread_id", + ), values=[ _gen_entry(user_id, actions) for user_id, actions in user_id_actions.items() @@ -981,6 +1091,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) # Replace the previous summary with the new counts. + # + # TODO(threads): Upsert per-thread instead of setting them all to main. self.db_pool.simple_upsert_txn( txn, table="event_push_summary", @@ -990,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "unread_count": unread_count, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, + "thread_id": "main", }, ) @@ -1138,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) + # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", key_names=("user_id", "room_id"), key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering"), + value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), value_values=[ ( summary.notif_count, summary.unread_count, summary.stream_ordering, + "main", ) for summary in summaries.values() ], diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a4010ee28d..c0b4080e4b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2192,9 +2192,9 @@ class PersistEventsStore: sql = """ INSERT INTO event_push_actions ( room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight, unread + topological_ordering, notif, highlight, unread, thread_id ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id FROM event_push_actions_staging WHERE event_id = ? """ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 719a12b0ae..ddb8e80b69 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -113,6 +113,24 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) + self.db_pool.updates.register_background_index_update( + "receipts_linearized_unique_index", + index_name="receipts_linearized_unique_index", + table="receipts_linearized", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + + self.db_pool.updates.register_background_index_update( + "receipts_graph_unique_index", + index_name="receipts_graph_unique_index", + table="receipts_graph", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -677,6 +695,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), + "thread_id": None, }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -824,6 +843,7 @@ class ReceiptsWorkerStore(SQLBaseStore): values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), + "thread_id": None, }, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 32cda5e3ba..38c9532bfd 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 72 # remember to update the list below when updating +SCHEMA_VERSION = 73 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -77,6 +77,10 @@ Changes in SCHEMA_VERSION = 72: - Tables related to groups are dropped. - Unused column application_services_state.last_txn is dropped - Cache invalidation stream id sequence now begins at 2 to match code expectation. + +Changes in SCHEMA_VERSION = 73; + - thread_id column is added to event_push_actions, event_push_actions_staging + event_push_summary, receipts_linearized, and receipts_graph. """ diff --git a/synapse/storage/schema/main/delta/72/06thread_notifications.sql b/synapse/storage/schema/main/delta/72/06thread_notifications.sql new file mode 100644 index 0000000000..2f4f5dac7a --- /dev/null +++ b/synapse/storage/schema/main/delta/72/06thread_notifications.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a nullable column for thread ID to the event push actions tables; this +-- will be filled in with a default value for any previously existing rows. +-- +-- After migration this can be made non-nullable. + +ALTER TABLE event_push_actions_staging ADD COLUMN thread_id TEXT; +ALTER TABLE event_push_actions ADD COLUMN thread_id TEXT; +ALTER TABLE event_push_summary ADD COLUMN thread_id TEXT; + +-- Update the unique index for `event_push_summary`. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7006, 'event_push_summary_unique_index2', '{}'); + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7006, 'event_push_backfill_thread_id', '{}', 'event_push_summary_unique_index2'); diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres new file mode 100644 index 0000000000..55fff9e278 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a nullable column for thread ID to the receipts table; this allows a +-- receipt per user, per room, as well as an unthreaded receipt (corresponding +-- to a null thread ID). + +ALTER TABLE receipts_linearized ADD COLUMN thread_id TEXT; +ALTER TABLE receipts_graph ADD COLUMN thread_id TEXT; + +-- Rebuild the unique constraint with the thread_id. +ALTER TABLE receipts_linearized + ADD CONSTRAINT receipts_linearized_uniqueness_thread + UNIQUE (room_id, receipt_type, user_id, thread_id); + +ALTER TABLE receipts_graph + ADD CONSTRAINT receipts_graph_uniqueness_thread + UNIQUE (room_id, receipt_type, user_id, thread_id); diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite new file mode 100644 index 0000000000..232f67deb4 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite @@ -0,0 +1,70 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Allow multiple receipts per user per room via a nullable thread_id column. +-- +-- SQLite doesn't support modifying constraints to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE receipts_linearized_new ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + thread_id TEXT, + event_stream_ordering BIGINT, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +CREATE TABLE receipts_graph_new ( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + thread_id TEXT, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +-- Drop the old indexes. +DROP INDEX IF EXISTS receipts_linearized_id; +DROP INDEX IF EXISTS receipts_linearized_room_stream; +DROP INDEX IF EXISTS receipts_linearized_user; + +-- Copy the data. +INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data) + SELECT stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data + FROM receipts_linearized; +INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data) + SELECT room_id, receipt_type, user_id, event_ids, data + FROM receipts_graph; + +-- Drop the old tables. +DROP TABLE receipts_linearized; +DROP TABLE receipts_graph; + +-- Rename the tables. +ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized; +ALTER TABLE receipts_graph_new RENAME TO receipts_graph; + +-- Create the indices. +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); diff --git a/synapse/storage/schema/main/delta/72/08thread_receipts.sql b/synapse/storage/schema/main/delta/72/08thread_receipts.sql new file mode 100644 index 0000000000..e35b021f31 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/08thread_receipts.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7007, 'receipts_linearized_unique_index', '{}'); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7007, 'receipts_graph_unique_index', '{}'); diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 531a0db2d0..49a21e2e85 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -404,6 +404,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event.event_id, {user_id: actions for user_id, actions in push_actions}, False, + "main", ) ) return event, context -- cgit 1.5.1 From 42d261c32f13e2de7494a0ade77c1f7b646af1fe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Sep 2022 12:10:31 +0100 Subject: Port the push rule classes to Rust. (#13768) --- .rustfmt.toml | 1 + changelog.d/13768.misc | 1 + rust/Cargo.toml | 10 +- rust/src/lib.rs | 9 +- rust/src/push/base_rules.rs | 335 ++++++++++++++++ rust/src/push/mod.rs | 502 ++++++++++++++++++++++++ stubs/synapse/synapse_rust.pyi | 2 - stubs/synapse/synapse_rust/__init__.pyi | 2 + stubs/synapse/synapse_rust/push.pyi | 37 ++ synapse/handlers/push_rules.py | 5 +- synapse/push/baserules.py | 583 ---------------------------- synapse/push/bulk_push_rule_evaluator.py | 7 +- synapse/push/clientformat.py | 5 +- synapse/storage/databases/main/push_rule.py | 23 +- tests/handlers/test_deactivate_account.py | 27 +- 15 files changed, 932 insertions(+), 617 deletions(-) create mode 100644 .rustfmt.toml create mode 100644 changelog.d/13768.misc create mode 100644 rust/src/push/base_rules.rs create mode 100644 rust/src/push/mod.rs delete mode 100644 stubs/synapse/synapse_rust.pyi create mode 100644 stubs/synapse/synapse_rust/__init__.pyi create mode 100644 stubs/synapse/synapse_rust/push.pyi delete mode 100644 synapse/push/baserules.py (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000000..bf96e7743d --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +group_imports = "StdExternalCrate" diff --git a/changelog.d/13768.misc b/changelog.d/13768.misc new file mode 100644 index 0000000000..28bddb7059 --- /dev/null +++ b/changelog.d/13768.misc @@ -0,0 +1 @@ +Port push rules to using Rust. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index deddf3cec2..8dc5f93ff1 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -18,7 +18,15 @@ crate-type = ["cdylib"] name = "synapse.synapse_rust" [dependencies] -pyo3 = { version = "0.16.5", features = ["extension-module", "macros", "abi3", "abi3-py37"] } +anyhow = "1.0.63" +lazy_static = "1.4.0" +log = "0.4.17" +pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] } +pyo3-log = "0.7.0" +pythonize = "0.17.0" +regex = "1.6.0" +serde = { version = "1.0.144", features = ["derive"] } +serde_json = "1.0.85" [build-dependencies] blake2 = "0.10.4" diff --git a/rust/src/lib.rs b/rust/src/lib.rs index ba42465fb8..c7b60e58a7 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,5 +1,7 @@ use pyo3::prelude::*; +pub mod push; + /// Returns the hash of all the rust source files at the time it was compiled. /// /// Used by python to detect if the rust library is outdated. @@ -17,8 +19,13 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { /// The entry point for defining the Python module. #[pymodule] -fn synapse_rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { + pyo3_log::init(); + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; + + push::register_module(py, m)?; + Ok(()) } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs new file mode 100644 index 0000000000..7c62bc4849 --- /dev/null +++ b/rust/src/push/base_rules.rs @@ -0,0 +1,335 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Contains the definitions of the "base" push rules. + +use std::borrow::Cow; +use std::collections::HashMap; + +use lazy_static::lazy_static; +use serde_json::Value; + +use super::KnownCondition; +use crate::push::Action; +use crate::push::Condition; +use crate::push::EventMatchCondition; +use crate::push::PushRule; +use crate::push::SetTweak; +use crate::push::TweakValue; + +const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("highlight"), + value: None, + other_keys: Value::Null, +}); + +const HIGHLIGHT_FALSE_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("highlight"), + value: Some(TweakValue::Other(Value::Bool(false))), + other_keys: Value::Null, +}); + +const SOUND_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("sound"), + value: Some(TweakValue::String(Cow::Borrowed("default"))), + other_keys: Value::Null, +}); + +const RING_ACTION: Action = Action::SetTweak(SetTweak { + set_tweak: Cow::Borrowed("sound"), + value: Some(TweakValue::String(Cow::Borrowed("ring"))), + other_keys: Value::Null, +}); + +pub const BASE_PREPEND_OVERRIDE_RULES: &[PushRule] = &[PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.master"), + priority_class: 5, + conditions: Cow::Borrowed(&[]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: false, +}]; + +pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.suppress_notices"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("content.msgtype"), + pattern: Some(Cow::Borrowed("m.notice")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.invite_for_me"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.member")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("content.membership"), + pattern: Some(Cow::Borrowed("invite")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_id")), + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.member_event"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.member")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::ContainsDisplayName)]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::SenderNotificationPermission { + key: Cow::Borrowed("room"), + }), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("content.body"), + pattern: Some(Cow::Borrowed("@room")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.tombstone"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.tombstone")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: Some(Cow::Borrowed("")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.m.rule.reaction"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.reaction")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::DontNotify]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/override/.org.matrix.msc3786.rule.room.server_acl"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.server_acl")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: Some(Cow::Borrowed("")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[]), + default: true, + default_enabled: true, + }, +]; + +pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { + rule_id: Cow::Borrowed("global/content/.m.rule.contains_user_name"), + priority_class: 4, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("content.body"), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_localpart")), + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, +}]; + +pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.call"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.call.invite")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, RING_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.room_one_to_one"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.message")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.encrypted_room_one_to_one"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.encrypted")), + pattern_type: None, + })), + Condition::Known(KnownCondition::RoomMemberCount { + is: Some(Cow::Borrowed("2")), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { + rel_type: Cow::Borrowed("m.thread"), + sender: None, + sender_type: Some(Cow::Borrowed("user_id")), + })]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.message"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.message")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.m.rule.encrypted"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( + EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("m.room.encrypted")), + pattern_type: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.im.vector.jitsi"), + priority_class: 1, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("type"), + pattern: Some(Cow::Borrowed("im.vector.modular.widgets")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("content.type"), + pattern: Some(Cow::Borrowed("jitsi")), + pattern_type: None, + })), + Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: Cow::Borrowed("state_key"), + pattern: Some(Cow::Borrowed("*")), + pattern_type: None, + })), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), + default: true, + default_enabled: true, + }, +]; + +lazy_static! { + pub static ref BASE_RULES_BY_ID: HashMap<&'static str, &'static PushRule> = + BASE_PREPEND_OVERRIDE_RULES + .iter() + .chain(BASE_APPEND_OVERRIDE_RULES.iter()) + .chain(BASE_APPEND_CONTENT_RULES.iter()) + .chain(BASE_APPEND_UNDERRIDE_RULES.iter()) + .map(|rule| { (&*rule.rule_id, rule) }) + .collect(); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs new file mode 100644 index 0000000000..de6764e7c5 --- /dev/null +++ b/rust/src/push/mod.rs @@ -0,0 +1,502 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! An implementation of Matrix push rules. +//! +//! The `Cow<_>` type is used extensively within this module to allow creating +//! the base rules as constants (in Rust constants can't require explicit +//! allocation atm). +//! +//! --- +//! +//! Push rules is the system used to determine which events trigger a push (and a +//! bump in notification counts). +//! +//! This consists of a list of "push rules" for each user, where a push rule is a +//! pair of "conditions" and "actions". When a user receives an event Synapse +//! iterates over the list of push rules until it finds one where all the conditions +//! match the event, at which point "actions" describe the outcome (e.g. notify, +//! highlight, etc). +//! +//! Push rules are split up into 5 different "kinds" (aka "priority classes"), which +//! are run in order: +//! 1. Override — highest priority rules, e.g. always ignore notices +//! 2. Content — content specific rules, e.g. @ notifications +//! 3. Room — per room rules, e.g. enable/disable notifications for all messages +//! in a room +//! 4. Sender — per sender rules, e.g. never notify for messages from a given +//! user +//! 5. Underride — the lowest priority "default" rules, e.g. notify for every +//! message. +//! +//! The set of "base rules" are the list of rules that every user has by default. A +//! user can modify their copy of the push rules in one of three ways: +//! +//! 1. Adding a new push rule of a certain kind +//! 2. Changing the actions of a base rule +//! 3. Enabling/disabling a base rule. +//! +//! The base rules are split into whether they come before or after a particular +//! kind, so the order of push rule evaluation would be: base rules for before +//! "override" kind, user defined "override" rules, base rules after "override" +//! kind, etc, etc. + +use std::borrow::Cow; +use std::collections::{BTreeMap, HashMap, HashSet}; + +use anyhow::{Context, Error}; +use log::warn; +use pyo3::prelude::*; +use pythonize::pythonize; +use serde::de::Error as _; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +mod base_rules; + +/// Called when registering modules with python. +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "push")?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import push` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.push", child_module)?; + + Ok(()) +} + +#[pyfunction] +fn get_base_rule_ids() -> HashSet<&'static str> { + base_rules::BASE_RULES_BY_ID.keys().copied().collect() +} + +/// A single push rule for a user. +#[derive(Debug, Clone)] +#[pyclass(frozen)] +pub struct PushRule { + /// A unique ID for this rule + pub rule_id: Cow<'static, str>, + /// The "kind" of push rule this is (see `PRIORITY_CLASS_MAP` in Python) + #[pyo3(get)] + pub priority_class: i32, + /// The conditions that must all match for actions to be applied + pub conditions: Cow<'static, [Condition]>, + /// The actions to apply if all conditions are met + pub actions: Cow<'static, [Action]>, + /// Whether this is a base rule + #[pyo3(get)] + pub default: bool, + /// Whether this is enabled by default + #[pyo3(get)] + pub default_enabled: bool, +} + +#[pymethods] +impl PushRule { + #[staticmethod] + pub fn from_db( + rule_id: String, + priority_class: i32, + conditions: &str, + actions: &str, + ) -> Result { + let conditions = serde_json::from_str(conditions).context("parsing conditions")?; + let actions = serde_json::from_str(actions).context("parsing actions")?; + + Ok(PushRule { + rule_id: Cow::Owned(rule_id), + priority_class, + conditions, + actions, + default: false, + default_enabled: true, + }) + } + + #[getter] + fn rule_id(&self) -> &str { + &self.rule_id + } + + #[getter] + fn actions(&self) -> Vec { + self.actions.clone().into_owned() + } + + #[getter] + fn conditions(&self) -> Vec { + self.conditions.clone().into_owned() + } + + fn __repr__(&self) -> String { + format!( + "", + self.rule_id, self.conditions, self.actions + ) + } +} + +/// The "action" Synapse should perform for a matching push rule. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Action { + DontNotify, + Notify, + Coalesce, + SetTweak(SetTweak), + + // An unrecognized custom action. + Unknown(Value), +} + +impl IntoPy for Action { + fn into_py(self, py: Python<'_>) -> PyObject { + // When we pass the `Action` struct to Python we want it to be converted + // to a dict. We use `pythonize`, which converts the struct using the + // `serde` serialization. + pythonize(py, &self).expect("valid action") + } +} + +/// The body of a `SetTweak` push action. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct SetTweak { + set_tweak: Cow<'static, str>, + + #[serde(skip_serializing_if = "Option::is_none")] + value: Option, + + // This picks up any other fields that may have been added by clients. + // These get added when we convert the `Action` to a python object. + #[serde(flatten)] + other_keys: Value, +} + +/// The value of a `set_tweak`. +/// +/// We need this (rather than using `TweakValue` directly) so that we can use +/// `&'static str` in the value when defining the constant base rules. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum TweakValue { + String(Cow<'static, str>), + Other(Value), +} + +impl Serialize for Action { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Action::DontNotify => serializer.serialize_str("dont_notify"), + Action::Notify => serializer.serialize_str("notify"), + Action::Coalesce => serializer.serialize_str("coalesce"), + Action::SetTweak(tweak) => tweak.serialize(serializer), + Action::Unknown(value) => value.serialize(serializer), + } + } +} + +/// Simple helper class for deserializing Action from JSON. +#[derive(Deserialize)] +#[serde(untagged)] +enum ActionDeserializeHelper { + Str(String), + SetTweak(SetTweak), + Unknown(Value), +} + +impl<'de> Deserialize<'de> for Action { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let helper: ActionDeserializeHelper = Deserialize::deserialize(deserializer)?; + match helper { + ActionDeserializeHelper::Str(s) => match &*s { + "dont_notify" => Ok(Action::DontNotify), + "notify" => Ok(Action::Notify), + "coalesce" => Ok(Action::Coalesce), + _ => Err(D::Error::custom("unrecognized action")), + }, + ActionDeserializeHelper::SetTweak(set_tweak) => Ok(Action::SetTweak(set_tweak)), + ActionDeserializeHelper::Unknown(value) => Ok(Action::Unknown(value)), + } + } +} + +/// A condition used in push rules to match against an event. +/// +/// We need this split as `serde` doesn't give us the ability to have a +/// "catchall" variant in tagged enums. +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum Condition { + /// A recognized condition that we can match against + Known(KnownCondition), + /// An unrecognized condition that we ignore. + Unknown(Value), +} + +/// The set of "known" conditions that we can handle. +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "kind")] +pub enum KnownCondition { + EventMatch(EventMatchCondition), + ContainsDisplayName, + RoomMemberCount { + #[serde(skip_serializing_if = "Option::is_none")] + is: Option>, + }, + SenderNotificationPermission { + key: Cow<'static, str>, + }, + #[serde(rename = "org.matrix.msc3772.relation_match")] + RelationMatch { + rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + sender: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + sender_type: Option>, + }, +} + +impl IntoPy for Condition { + fn into_py(self, py: Python<'_>) -> PyObject { + pythonize(py, &self).expect("valid condition") + } +} + +/// The body of a [`Condition::EventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct EventMatchCondition { + key: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + pattern: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pattern_type: Option>, +} + +/// The collection of push rules for a user. +#[derive(Debug, Clone, Default)] +#[pyclass(frozen)] +struct PushRules { + /// Custom push rules that override a base rule. + overridden_base_rules: HashMap, PushRule>, + + /// Custom rules that come between the prepend/append override base rules. + override_rules: Vec, + /// Custom rules that come before the base content rules. + content: Vec, + /// Custom rules that come before the base room rules. + room: Vec, + /// Custom rules that come before the base sender rules. + sender: Vec, + /// Custom rules that come before the base underride rules. + underride: Vec, +} + +#[pymethods] +impl PushRules { + #[new] + fn new(rules: Vec) -> PushRules { + let mut push_rules: PushRules = Default::default(); + + for rule in rules { + if let Some(&o) = base_rules::BASE_RULES_BY_ID.get(&*rule.rule_id) { + push_rules.overridden_base_rules.insert( + rule.rule_id.clone(), + PushRule { + actions: rule.actions.clone(), + ..o.clone() + }, + ); + + continue; + } + + match rule.priority_class { + 5 => push_rules.override_rules.push(rule), + 4 => push_rules.content.push(rule), + 3 => push_rules.room.push(rule), + 2 => push_rules.sender.push(rule), + 1 => push_rules.underride.push(rule), + _ => { + warn!( + "Unrecognized priority class for rule {}: {}", + rule.rule_id, rule.priority_class + ); + } + } + } + + push_rules + } + + /// Returns the list of all rules, including base rules, in the order they + /// should be executed in. + fn rules(&self) -> Vec { + self.iter().cloned().collect() + } +} + +impl PushRules { + /// Iterates over all the rules, including base rules, in the order they + /// should be executed in. + pub fn iter(&self) -> impl Iterator { + base_rules::BASE_PREPEND_OVERRIDE_RULES + .iter() + .chain(self.override_rules.iter()) + .chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter()) + .chain(self.content.iter()) + .chain(base_rules::BASE_APPEND_CONTENT_RULES.iter()) + .chain(self.room.iter()) + .chain(self.sender.iter()) + .chain(self.underride.iter()) + .chain(base_rules::BASE_APPEND_UNDERRIDE_RULES.iter()) + .map(|rule| { + self.overridden_base_rules + .get(&*rule.rule_id) + .unwrap_or(rule) + }) + } +} + +/// A wrapper around `PushRules` that checks the enabled state of rules and +/// filters out disabled experimental rules. +#[derive(Debug, Clone, Default)] +#[pyclass(frozen)] +pub struct FilteredPushRules { + push_rules: PushRules, + enabled_map: BTreeMap, + msc3786_enabled: bool, + msc3772_enabled: bool, +} + +#[pymethods] +impl FilteredPushRules { + #[new] + fn py_new( + push_rules: PushRules, + enabled_map: BTreeMap, + msc3786_enabled: bool, + msc3772_enabled: bool, + ) -> Self { + Self { + push_rules, + enabled_map, + msc3786_enabled, + msc3772_enabled, + } + } + + /// Returns the list of all rules and their enabled state, including base + /// rules, in the order they should be executed in. + fn rules(&self) -> Vec<(PushRule, bool)> { + self.iter().map(|(r, e)| (r.clone(), e)).collect() + } +} + +impl FilteredPushRules { + /// Iterates over all the rules and their enabled state, including base + /// rules, in the order they should be executed in. + fn iter(&self) -> impl Iterator { + self.push_rules + .iter() + .filter(|rule| { + // Ignore disabled experimental push rules + if !self.msc3786_enabled + && rule.rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl" + { + return false; + } + + if !self.msc3772_enabled + && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" + { + return false; + } + + true + }) + .map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) + } +} + +#[test] +fn test_serialize_condition() { + let condition = Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + key: "content.body".into(), + pattern: Some("coffee".into()), + pattern_type: None, + })); + + let json = serde_json::to_string(&condition).unwrap(); + assert_eq!( + json, + r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"# + ) +} + +#[test] +fn test_deserialize_condition() { + let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#; + + let _: Condition = serde_json::from_str(json).unwrap(); +} + +#[test] +fn test_deserialize_custom_condition() { + let json = r#"{"kind":"custom_tag"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!(condition, Condition::Unknown(_))); + + let new_json = serde_json::to_string(&condition).unwrap(); + assert_eq!(json, new_json); +} + +#[test] +fn test_deserialize_action() { + let _: Action = serde_json::from_str(r#""notify""#).unwrap(); + let _: Action = serde_json::from_str(r#""dont_notify""#).unwrap(); + let _: Action = serde_json::from_str(r#""coalesce""#).unwrap(); + let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap(); +} + +#[test] +fn test_custom_action() { + let json = r#"{"some_custom":"action_fields"}"#; + + let action: Action = serde_json::from_str(json).unwrap(); + assert!(matches!(action, Action::Unknown(_))); + + let new_json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, new_json); +} diff --git a/stubs/synapse/synapse_rust.pyi b/stubs/synapse/synapse_rust.pyi deleted file mode 100644 index 8658d3138f..0000000000 --- a/stubs/synapse/synapse_rust.pyi +++ /dev/null @@ -1,2 +0,0 @@ -def sum_as_string(a: int, b: int) -> str: ... -def get_rust_file_digest() -> str: ... diff --git a/stubs/synapse/synapse_rust/__init__.pyi b/stubs/synapse/synapse_rust/__init__.pyi new file mode 100644 index 0000000000..8658d3138f --- /dev/null +++ b/stubs/synapse/synapse_rust/__init__.pyi @@ -0,0 +1,2 @@ +def sum_as_string(a: int, b: int) -> str: ... +def get_rust_file_digest() -> str: ... diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi new file mode 100644 index 0000000000..93c4e69d42 --- /dev/null +++ b/stubs/synapse/synapse_rust/push.pyi @@ -0,0 +1,37 @@ +from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union + +from synapse.types import JsonDict + +class PushRule: + @property + def rule_id(self) -> str: ... + @property + def priority_class(self) -> int: ... + @property + def conditions(self) -> Sequence[Mapping[str, str]]: ... + @property + def actions(self) -> Sequence[Union[Mapping[str, Any], str]]: ... + @property + def default(self) -> bool: ... + @property + def default_enabled(self) -> bool: ... + @staticmethod + def from_db( + rule_id: str, priority_class: int, conditions: str, actions: str + ) -> "PushRule": ... + +class PushRules: + def __init__(self, rules: Collection[PushRule]): ... + def rules(self) -> Collection[PushRule]: ... + +class FilteredPushRules: + def __init__( + self, + push_rules: PushRules, + enabled_map: Dict[str, bool], + msc3786_enabled: bool, + msc3772_enabled: bool, + ): ... + def rules(self) -> Collection[Tuple[PushRule, bool]]: ... + +def get_base_rule_ids() -> Collection[str]: ... diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py index 2599160bcc..1219672a59 100644 --- a/synapse/handlers/push_rules.py +++ b/synapse/handlers/push_rules.py @@ -16,14 +16,17 @@ from typing import TYPE_CHECKING, List, Optional, Union import attr from synapse.api.errors import SynapseError, UnrecognizedRequestError -from synapse.push.baserules import BASE_RULE_IDS from synapse.storage.push_rule import RuleNotFoundException +from synapse.synapse_rust.push import get_base_rule_ids from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer +BASE_RULE_IDS = get_base_rule_ids() + + @attr.s(slots=True, frozen=True, auto_attribs=True) class RuleSpec: scope: str diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py deleted file mode 100644 index 440205e80c..0000000000 --- a/synapse/push/baserules.py +++ /dev/null @@ -1,583 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Push rules is the system used to determine which events trigger a push (and a -bump in notification counts). - -This consists of a list of "push rules" for each user, where a push rule is a -pair of "conditions" and "actions". When a user receives an event Synapse -iterates over the list of push rules until it finds one where all the conditions -match the event, at which point "actions" describe the outcome (e.g. notify, -highlight, etc). - -Push rules are split up into 5 different "kinds" (aka "priority classes"), which -are run in order: - 1. Override — highest priority rules, e.g. always ignore notices - 2. Content — content specific rules, e.g. @ notifications - 3. Room — per room rules, e.g. enable/disable notifications for all messages - in a room - 4. Sender — per sender rules, e.g. never notify for messages from a given - user - 5. Underride — the lowest priority "default" rules, e.g. notify for every - message. - -The set of "base rules" are the list of rules that every user has by default. A -user can modify their copy of the push rules in one of three ways: - - 1. Adding a new push rule of a certain kind - 2. Changing the actions of a base rule - 3. Enabling/disabling a base rule. - -The base rules are split into whether they come before or after a particular -kind, so the order of push rule evaluation would be: base rules for before -"override" kind, user defined "override" rules, base rules after "override" -kind, etc, etc. -""" - -import itertools -import logging -from typing import Dict, Iterator, List, Mapping, Sequence, Tuple, Union - -import attr - -from synapse.config.experimental import ExperimentalConfig -from synapse.push.rulekinds import PRIORITY_CLASS_MAP - -logger = logging.getLogger(__name__) - - -@attr.s(auto_attribs=True, slots=True, frozen=True) -class PushRule: - """A push rule - - Attributes: - rule_id: a unique ID for this rule - priority_class: what "kind" of push rule this is (see - `PRIORITY_CLASS_MAP` for mapping between int and kind) - conditions: the sequence of conditions that all need to match - actions: the actions to apply if all conditions are met - default: is this a base rule? - default_enabled: is this enabled by default? - """ - - rule_id: str - priority_class: int - conditions: Sequence[Mapping[str, str]] - actions: Sequence[Union[str, Mapping]] - default: bool = False - default_enabled: bool = True - - -@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False) -class PushRules: - """A collection of push rules for an account. - - Can be iterated over, producing push rules in priority order. - """ - - # A mapping from rule ID to push rule that overrides a base rule. These will - # be returned instead of the base rule. - overriden_base_rules: Dict[str, PushRule] = attr.Factory(dict) - - # The following stores the custom push rules at each priority class. - # - # We keep these separate (rather than combining into one big list) to avoid - # copying the base rules around all the time. - override: List[PushRule] = attr.Factory(list) - content: List[PushRule] = attr.Factory(list) - room: List[PushRule] = attr.Factory(list) - sender: List[PushRule] = attr.Factory(list) - underride: List[PushRule] = attr.Factory(list) - - def __iter__(self) -> Iterator[PushRule]: - # When iterating over the push rules we need to return the base rules - # interspersed at the correct spots. - for rule in itertools.chain( - BASE_PREPEND_OVERRIDE_RULES, - self.override, - BASE_APPEND_OVERRIDE_RULES, - self.content, - BASE_APPEND_CONTENT_RULES, - self.room, - self.sender, - self.underride, - BASE_APPEND_UNDERRIDE_RULES, - ): - # Check if a base rule has been overriden by a custom rule. If so - # return that instead. - override_rule = self.overriden_base_rules.get(rule.rule_id) - if override_rule: - yield override_rule - else: - yield rule - - def __len__(self) -> int: - # The length is mostly used by caches to get a sense of "size" / amount - # of memory this object is using, so we only count the number of custom - # rules. - return ( - len(self.overriden_base_rules) - + len(self.override) - + len(self.content) - + len(self.room) - + len(self.sender) - + len(self.underride) - ) - - -@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False) -class FilteredPushRules: - """A wrapper around `PushRules` that filters out disabled experimental push - rules, and includes the "enabled" state for each rule when iterated over. - """ - - push_rules: PushRules - enabled_map: Dict[str, bool] - experimental_config: ExperimentalConfig - - def __iter__(self) -> Iterator[Tuple[PushRule, bool]]: - for rule in self.push_rules: - if not _is_experimental_rule_enabled( - rule.rule_id, self.experimental_config - ): - continue - - enabled = self.enabled_map.get(rule.rule_id, rule.default_enabled) - - yield rule, enabled - - def __len__(self) -> int: - return len(self.push_rules) - - -DEFAULT_EMPTY_PUSH_RULES = PushRules() - - -def compile_push_rules(rawrules: List[PushRule]) -> PushRules: - """Given a set of custom push rules return a `PushRules` instance (which - includes the base rules). - """ - - if not rawrules: - # Fast path to avoid allocating empty lists when there are no custom - # rules for the user. - return DEFAULT_EMPTY_PUSH_RULES - - rules = PushRules() - - for rule in rawrules: - # We need to decide which bucket each custom push rule goes into. - - # If it has the same ID as a base rule then it overrides that... - overriden_base_rule = BASE_RULES_BY_ID.get(rule.rule_id) - if overriden_base_rule: - rules.overriden_base_rules[rule.rule_id] = attr.evolve( - overriden_base_rule, actions=rule.actions - ) - continue - - # ... otherwise it gets added to the appropriate priority class bucket - collection: List[PushRule] - if rule.priority_class == 5: - collection = rules.override - elif rule.priority_class == 4: - collection = rules.content - elif rule.priority_class == 3: - collection = rules.room - elif rule.priority_class == 2: - collection = rules.sender - elif rule.priority_class == 1: - collection = rules.underride - elif rule.priority_class <= 0: - logger.info( - "Got rule with priority class less than zero, but doesn't override a base rule: %s", - rule, - ) - continue - else: - # We log and continue here so as not to break event sending - logger.error("Unknown priority class: %", rule.priority_class) - continue - - collection.append(rule) - - return rules - - -def _is_experimental_rule_enabled( - rule_id: str, experimental_config: ExperimentalConfig -) -> bool: - """Used by `FilteredPushRules` to filter out experimental rules when they - have not been enabled. - """ - if ( - rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl" - and not experimental_config.msc3786_enabled - ): - return False - if ( - rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - and not experimental_config.msc3772_enabled - ): - return False - return True - - -BASE_APPEND_CONTENT_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["content"], - rule_id="global/content/.m.rule.contains_user_name", - conditions=[ - { - "kind": "event_match", - "key": "content.body", - # Match the localpart of the requester's MXID. - "pattern_type": "user_localpart", - } - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, - ], - ) -] - - -BASE_PREPEND_OVERRIDE_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.master", - default_enabled=False, - conditions=[], - actions=["dont_notify"], - ) -] - - -BASE_APPEND_OVERRIDE_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.suppress_notices", - conditions=[ - { - "kind": "event_match", - "key": "content.msgtype", - "pattern": "m.notice", - "_cache_key": "_suppress_notices", - } - ], - actions=["dont_notify"], - ), - # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event - # otherwise invites will be matched by .m.rule.member_event - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.invite_for_me", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.member", - "_cache_key": "_member", - }, - { - "kind": "event_match", - "key": "content.membership", - "pattern": "invite", - "_cache_key": "_invite_member", - }, - # Match the requester's MXID. - {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - # Will we sometimes want to know about people joining and leaving? - # Perhaps: if so, this could be expanded upon. Seems the most usual case - # is that we don't though. We add this override rule so that even if - # the room rule is set to notify, we don't get notifications about - # join/leave/avatar/displayname events. - # See also: https://matrix.org/jira/browse/SYN-607 - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.member_event", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.member", - "_cache_key": "_member", - } - ], - actions=["dont_notify"], - ), - # This was changed from underride to override so it's closer in priority - # to the content rules where the user name highlight rule lives. This - # way a room rule is lower priority than both but a custom override rule - # is higher priority than both. - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.contains_display_name", - conditions=[{"kind": "contains_display_name"}], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, - ], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.roomnotif", - conditions=[ - { - "kind": "event_match", - "key": "content.body", - "pattern": "@room", - "_cache_key": "_roomnotif_content", - }, - { - "kind": "sender_notification_permission", - "key": "room", - "_cache_key": "_roomnotif_pl", - }, - ], - actions=["notify", {"set_tweak": "highlight", "value": True}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.tombstone", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.tombstone", - "_cache_key": "_tombstone", - }, - { - "kind": "event_match", - "key": "state_key", - "pattern": "", - "_cache_key": "_tombstone_statekey", - }, - ], - actions=["notify", {"set_tweak": "highlight", "value": True}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.m.rule.reaction", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.reaction", - "_cache_key": "_reaction", - } - ], - actions=["dont_notify"], - ), - # XXX: This is an experimental rule that is only enabled if msc3786_enabled - # is enabled, if it is not the rule gets filtered out in _load_rules() in - # PushRulesWorkerStore - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["override"], - rule_id="global/override/.org.matrix.msc3786.rule.room.server_acl", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.server_acl", - "_cache_key": "_room_server_acl", - }, - { - "kind": "event_match", - "key": "state_key", - "pattern": "", - "_cache_key": "_room_server_acl_state_key", - }, - ], - actions=[], - ), -] - - -BASE_APPEND_UNDERRIDE_RULES = [ - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.call", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.call.invite", - "_cache_key": "_call", - } - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "ring"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - # XXX: once m.direct is standardised everywhere, we should use it to detect - # a DM from the user's perspective rather than this heuristic. - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.room_one_to_one", - conditions=[ - {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"}, - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.message", - "_cache_key": "_message", - }, - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - # XXX: this is going to fire for events which aren't m.room.messages - # but are encrypted (e.g. m.call.*)... - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.encrypted_room_one_to_one", - conditions=[ - {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"}, - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.encrypted", - "_cache_key": "_encrypted", - }, - ], - actions=[ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight", "value": False}, - ], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.org.matrix.msc3772.thread_reply", - conditions=[ - { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.thread", - # Match the requester's MXID. - "sender_type": "user_id", - } - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.message", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.message", - "_cache_key": "_message", - } - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), - # XXX: this is going to fire for events which aren't m.room.messages - # but are encrypted (e.g. m.call.*)... - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.m.rule.encrypted", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.encrypted", - "_cache_key": "_encrypted", - } - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), - PushRule( - default=True, - priority_class=PRIORITY_CLASS_MAP["underride"], - rule_id="global/underride/.im.vector.jitsi", - conditions=[ - { - "kind": "event_match", - "key": "type", - "pattern": "im.vector.modular.widgets", - "_cache_key": "_type_modular_widgets", - }, - { - "kind": "event_match", - "key": "content.type", - "pattern": "jitsi", - "_cache_key": "_content_type_jitsi", - }, - { - "kind": "event_match", - "key": "state_key", - "pattern": "*", - "_cache_key": "_is_state_event", - }, - ], - actions=["notify", {"set_tweak": "highlight", "value": False}], - ), -] - - -BASE_RULE_IDS = set() - -BASE_RULES_BY_ID: Dict[str, PushRule] = {} - -for r in BASE_APPEND_CONTENT_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r - -for r in BASE_PREPEND_OVERRIDE_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r - -for r in BASE_APPEND_OVERRIDE_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r - -for r in BASE_APPEND_UNDERRIDE_RULES: - BASE_RULE_IDS.add(r.rule_id) - BASE_RULES_BY_ID[r.rule_id] = r diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 3846fbc5f0..404379ef67 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -37,11 +37,11 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter +from synapse.synapse_rust.push import FilteredPushRules, PushRule from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state -from .baserules import FilteredPushRules, PushRule from .push_rule_evaluator import PushRuleEvaluatorForEvent if TYPE_CHECKING: @@ -280,7 +280,8 @@ class BulkPushRuleEvaluator: thread_id = "main" if relation: relations = await self._get_mutual_relations( - relation.parent_id, itertools.chain(*rules_by_user.values()) + relation.parent_id, + itertools.chain(*(r.rules() for r in rules_by_user.values())), ) if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id @@ -333,7 +334,7 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - for rule, enabled in rules: + for rule, enabled in rules.rules(): if not enabled: continue diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 73618d9234..ebc13beda1 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -16,10 +16,9 @@ import copy from typing import Any, Dict, List, Optional from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP +from synapse.synapse_rust.push import FilteredPushRules, PushRule from synapse.types import UserID -from .baserules import FilteredPushRules, PushRule - def format_push_rules_for_user( user: UserID, ruleslist: FilteredPushRules @@ -34,7 +33,7 @@ def format_push_rules_for_user( rules["global"] = _add_empty_priority_class_arrays(rules["global"]) - for r, enabled in ruleslist: + for r, enabled in ruleslist.rules(): template_name = _priority_class_to_template_name(r.priority_class) rulearray = rules["global"][template_name] diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 5079edd1e0..ed17b2e70c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,9 +30,8 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -51,6 +50,7 @@ from synapse.storage.util.id_generators import ( IdGenerator, StreamIdGenerator, ) +from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -72,18 +72,25 @@ def _load_rules( """ ruleslist = [ - PushRule( + PushRule.from_db( rule_id=rawrule["rule_id"], priority_class=rawrule["priority_class"], - conditions=db_to_json(rawrule["conditions"]), - actions=db_to_json(rawrule["actions"]), + conditions=rawrule["conditions"], + actions=rawrule["actions"], ) for rawrule in rawrules ] - push_rules = compile_push_rules(ruleslist) + push_rules = PushRules( + ruleslist, + ) - filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config) + filtered_rules = FilteredPushRules( + push_rules, + enabled_map, + msc3786_enabled=experimental_config.msc3786_enabled, + msc3772_enabled=experimental_config.msc3772_enabled, + ) return filtered_rules @@ -845,7 +852,7 @@ class PushRuleStore(PushRulesWorkerStore): user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room - for rule, enabled in user_push_rules: + for rule, enabled in user_push_rules.rules(): if not enabled: continue diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 7b9b711521..bce65fab7d 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -15,11 +15,11 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import AccountDataTypes -from synapse.push.baserules import PushRule from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest import admin from synapse.rest.client import account, login from synapse.server import HomeServer +from synapse.synapse_rust.push import PushRule from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -161,20 +161,15 @@ class DeactivateAccountTestCase(HomeserverTestCase): self._store.get_push_rules_for_user(self.user) ) # Filter out default rules; we don't care - push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)] + push_rules = [ + r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r) + ] # Check our rule made it - self.assertEqual( - push_rules, - [ - PushRule( - rule_id="personal.override.rule1", - priority_class=5, - conditions=[], - actions=[], - ) - ], - push_rules, - ) + self.assertEqual(len(push_rules), 1) + self.assertEqual(push_rules[0].rule_id, "personal.override.rule1") + self.assertEqual(push_rules[0].priority_class, 5) + self.assertEqual(push_rules[0].conditions, []) + self.assertEqual(push_rules[0].actions, []) # Request the deactivation of our account self._deactivate_my_account() @@ -183,7 +178,9 @@ class DeactivateAccountTestCase(HomeserverTestCase): self._store.get_push_rules_for_user(self.user) ) # Filter out default rules; we don't care - push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)] + push_rules = [ + r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r) + ] # Check our rule no longer exists self.assertEqual(push_rules, [], push_rules) -- cgit 1.5.1 From 6caa3030835f879724c003a5b0dc66a6285451d8 Mon Sep 17 00:00:00 2001 From: Kateřina Churanová Date: Wed, 28 Sep 2022 14:31:53 +0200 Subject: fix: Push notifications for invite over federation (#13719) --- changelog.d/13719.bugfix | 1 + synapse/events/__init__.py | 4 ++++ synapse/handlers/federation.py | 13 ++++++++++--- synapse/handlers/federation_event.py | 1 + synapse/push/bulk_push_rule_evaluator.py | 10 +++++++--- synapse/push/push_rule_evaluator.py | 16 ++++++++-------- synapse/storage/controllers/persist_events.py | 10 ++++++---- synapse/storage/databases/main/events.py | 10 +++++----- 8 files changed, 42 insertions(+), 23 deletions(-) create mode 100644 changelog.d/13719.bugfix (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13719.bugfix b/changelog.d/13719.bugfix new file mode 100644 index 0000000000..4318f4daff --- /dev/null +++ b/changelog.d/13719.bugfix @@ -0,0 +1 @@ +Send invite push notifications for invite over federation. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index b2c9119fd0..030c3ca408 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -289,6 +289,10 @@ class _EventInternalMetadata: """ return self._dict.get("historical", False) + def is_notifiable(self) -> bool: + """Whether this event can trigger a push notification""" + return not self.is_outlier() or self.is_out_of_band_membership() + class EventBase(metaclass=abc.ABCMeta): @property diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 74580f60df..8f847ff845 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -149,6 +149,7 @@ class FederationHandler: self.http_client = hs.get_proxied_blacklisted_http_client() self._replication = hs.get_replication_data_handler() self._federation_event_handler = hs.get_federation_event_handler() + self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs @@ -956,9 +957,15 @@ class FederationHandler: ) context = EventContext.for_outlier(self._storage_controllers) - await self._federation_event_handler.persist_events_and_notify( - event.room_id, [(event, context)] - ) + + await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + try: + await self._federation_event_handler.persist_events_and_notify( + event.room_id, [(event, context)] + ) + except Exception: + await self.store.remove_push_actions_from_staging(event.event_id) + raise return event diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 2d7cde7506..3fac256881 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2170,6 +2170,7 @@ class FederationEventHandler: if instance != self._instance_name: # Limit the number of events sent over replication. We choose 200 # here as that is what we default to in `max_request_body_size(..)` + result = {} try: for batch in batch_iter(event_and_contexts, 200): result = await self._send_events( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 404379ef67..32313e3bcf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -173,7 +173,11 @@ class BulkPushRuleEvaluator: async def _get_power_levels_and_sender_level( self, event: EventBase, context: EventContext - ) -> Tuple[dict, int]: + ) -> Tuple[dict, Optional[int]]: + # There are no power levels and sender levels possible to get from outlier + if event.internal_metadata.is_outlier(): + return {}, None + event_types = auth_types_for_event(event.room_version, event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types(event_types) @@ -250,8 +254,8 @@ class BulkPushRuleEvaluator: should increment the unread count, and insert the results into the event_push_actions_staging table. """ - if event.internal_metadata.is_outlier(): - # This can happen due to out of band memberships + if not event.internal_metadata.is_notifiable(): + # Push rules for events that aren't notifiable can't be processed by this return # Disable counting as unread unless the experimental configuration is diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 3c5632cd91..f8176c5a42 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -42,18 +42,18 @@ IS_GLOB = re.compile(r"[\?\*\[\]]") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") -def _room_member_count( - ev: EventBase, condition: Mapping[str, Any], room_member_count: int -) -> bool: +def _room_member_count(condition: Mapping[str, Any], room_member_count: int) -> bool: return _test_ineq_condition(condition, room_member_count) def _sender_notification_permission( - ev: EventBase, condition: Mapping[str, Any], - sender_power_level: int, + sender_power_level: Optional[int], power_levels: Dict[str, Union[int, Dict[str, int]]], ) -> bool: + if sender_power_level is None: + return False + notif_level_key = condition.get("key") if notif_level_key is None: return False @@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent: self, event: EventBase, room_member_count: int, - sender_power_level: int, + sender_power_level: Optional[int], power_levels: Dict[str, Union[int, Dict[str, int]]], relations: Dict[str, Set[Tuple[str, str]]], relations_match_enabled: bool, @@ -198,10 +198,10 @@ class PushRuleEvaluatorForEvent: elif condition["kind"] == "contains_display_name": return self._contains_display_name(display_name) elif condition["kind"] == "room_member_count": - return _room_member_count(self._event, condition, self._room_member_count) + return _room_member_count(condition, self._room_member_count) elif condition["kind"] == "sender_notification_permission": return _sender_notification_permission( - self._event, condition, self._sender_power_level, self._power_levels + condition, self._sender_power_level, self._power_levels ) elif ( condition["kind"] == "org.matrix.msc3772.relation_match" diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 709cb792ed..06e71a8053 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -423,16 +423,18 @@ class EventsPersistenceStorageController: for d in ret_vals: replaced_events.update(d) - events = [] + persisted_events = [] for event, _ in events_and_contexts: existing_event_id = replaced_events.get(event.event_id) if existing_event_id: - events.append(await self.main_store.get_event(existing_event_id)) + persisted_events.append( + await self.main_store.get_event(existing_event_id) + ) else: - events.append(event) + persisted_events.append(event) return ( - events, + persisted_events, self.main_store.get_room_max_token(), ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b59eb7478b..bb489b8189 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2134,13 +2134,13 @@ class PersistEventsStore: appear in events_and_context. """ - # Only non outlier events will have push actions associated with them, + # Only notifiable events will have push actions associated with them, # so let's filter them out. (This makes joining large rooms faster, as # these queries took seconds to process all the state events). - non_outlier_events = [ + notifiable_events = [ event for event, _ in events_and_contexts - if not event.internal_metadata.is_outlier() + if event.internal_metadata.is_notifiable() ] sql = """ @@ -2153,7 +2153,7 @@ class PersistEventsStore: WHERE event_id = ? """ - if non_outlier_events: + if notifiable_events: txn.execute_batch( sql, ( @@ -2163,7 +2163,7 @@ class PersistEventsStore: event.depth, event.event_id, ) - for event in non_outlier_events + for event in notifiable_events ), ) -- cgit 1.5.1 From ebd9e2dac6495a1857617d1a76c9259a988f8bb4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Sep 2022 16:12:09 +0100 Subject: Implement push rule evaluation in Rust. (#13838) --- changelog.d/13838.misc | 1 + rust/Cargo.toml | 4 +- rust/benches/evaluator.rs | 149 ++++++++++++ rust/benches/glob.rs | 40 ++++ rust/build.rs | 2 +- rust/src/push/base_rules.rs | 1 + rust/src/push/evaluator.rs | 374 +++++++++++++++++++++++++++++++ rust/src/push/mod.rs | 28 ++- rust/src/push/utils.rs | 215 ++++++++++++++++++ stubs/synapse/synapse_rust/push.pyi | 19 +- synapse/push/bulk_push_rule_evaluator.py | 44 ++-- synapse/push/httppusher.py | 39 +++- synapse/push/push_rule_evaluator.py | 361 ----------------------------- tests/push/test_push_rule_evaluator.py | 20 +- 14 files changed, 894 insertions(+), 403 deletions(-) create mode 100644 changelog.d/13838.misc create mode 100644 rust/benches/evaluator.rs create mode 100644 rust/benches/glob.rs create mode 100644 rust/src/push/evaluator.rs create mode 100644 rust/src/push/utils.rs delete mode 100644 synapse/push/push_rule_evaluator.py (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13838.misc b/changelog.d/13838.misc new file mode 100644 index 0000000000..28bddb7059 --- /dev/null +++ b/changelog.d/13838.misc @@ -0,0 +1 @@ +Port push rules to using Rust. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 44263bf77e..cffaa5b51b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -11,7 +11,9 @@ rust-version = "1.58.1" [lib] name = "synapse" -crate-type = ["cdylib"] +# We generate a `cdylib` for Python and a standard `lib` for running +# tests/benchmarks. +crate-type = ["lib", "cdylib"] [package.metadata.maturin] # This is where we tell maturin where to place the built library. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs new file mode 100644 index 0000000000..ed411461d1 --- /dev/null +++ b/rust/benches/evaluator.rs @@ -0,0 +1,149 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(test)] +use synapse::push::{ + evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, +}; +use test::Bencher; + +extern crate test; + +#[bench] +fn bench_match_exact(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "room_id".into(), + pattern: Some("!room:server".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_match_word(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "content.body".into(), + pattern: Some("test".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_match_word_miss(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "content.body".into(), + pattern: Some("foobar".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(!matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_eval_message(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let rules = + FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false); + + b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); +} diff --git a/rust/benches/glob.rs b/rust/benches/glob.rs new file mode 100644 index 0000000000..b6697d9285 --- /dev/null +++ b/rust/benches/glob.rs @@ -0,0 +1,40 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(test)] + +use synapse::push::utils::{glob_to_regex, GlobMatchType}; +use test::Bencher; + +extern crate test; + +#[bench] +fn bench_whole(b: &mut Bencher) { + b.iter(|| glob_to_regex("test", GlobMatchType::Whole)); +} + +#[bench] +fn bench_word(b: &mut Bencher) { + b.iter(|| glob_to_regex("test", GlobMatchType::Word)); +} + +#[bench] +fn bench_whole_wildcard_run(b: &mut Bencher) { + b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole)); +} + +#[bench] +fn bench_word_wildcard_run(b: &mut Bencher) { + b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole)); +} diff --git a/rust/build.rs b/rust/build.rs index 2117975e56..ef370e6b41 100644 --- a/rust/build.rs +++ b/rust/build.rs @@ -22,7 +22,7 @@ fn main() -> Result<(), std::io::Error> { for entry in entries { if entry.is_dir() { - dirs.push(entry) + dirs.push(entry); } else { paths.push(entry.to_str().expect("valid rust paths").to_string()); } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 7c62bc4849..bb59676bde 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -262,6 +262,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ priority_class: 1, conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { rel_type: Cow::Borrowed("m.thread"), + event_type_pattern: None, sender: None, sender_type: Some(Cow::Borrowed("user_id")), })]), diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs new file mode 100644 index 0000000000..efe88ec76e --- /dev/null +++ b/rust/src/push/evaluator.rs @@ -0,0 +1,374 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + borrow::Cow, + collections::{BTreeMap, BTreeSet}, +}; + +use anyhow::{Context, Error}; +use lazy_static::lazy_static; +use log::warn; +use pyo3::prelude::*; +use regex::Regex; + +use super::{ + utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, + Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, +}; + +lazy_static! { + /// Used to parse the `is` clause in the room member count condition. + static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex"); +} + +/// Allows running a set of push rules against a particular event. +#[pyclass] +pub struct PushRuleEvaluator { + /// A mapping of "flattened" keys to string values in the event, e.g. + /// includes things like "type" and "content.msgtype". + flattened_keys: BTreeMap, + + /// The "content.body", if any. + body: String, + + /// The number of users in the room. + room_member_count: u64, + + /// The `notifications` section of the current power levels in the room. + notification_power_levels: BTreeMap, + + /// The relations related to the event as a mapping from relation type to + /// set of sender/event type 2-tuples. + relations: BTreeMap>, + + /// Is running "relation" conditions enabled? + relation_match_enabled: bool, + + /// The power level of the sender of the event, or None if event is an + /// outlier. + sender_power_level: Option, +} + +#[pymethods] +impl PushRuleEvaluator { + /// Create a new `PushRuleEvaluator`. See struct docstring for details. + #[new] + pub fn py_new( + flattened_keys: BTreeMap, + room_member_count: u64, + sender_power_level: Option, + notification_power_levels: BTreeMap, + relations: BTreeMap>, + relation_match_enabled: bool, + ) -> Result { + let body = flattened_keys + .get("content.body") + .cloned() + .unwrap_or_default(); + + Ok(PushRuleEvaluator { + flattened_keys, + body, + room_member_count, + notification_power_levels, + relations, + relation_match_enabled, + sender_power_level, + }) + } + + /// Run the evaluator with the given push rules, for the given user ID and + /// display name of the user. + /// + /// Passing in None will skip evaluating rules matching user ID and display + /// name. + /// + /// Returns the set of actions, if any, that match (filtering out any + /// `dont_notify` actions). + pub fn run( + &self, + push_rules: &FilteredPushRules, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Vec { + 'outer: for (push_rule, enabled) in push_rules.iter() { + if !enabled { + continue; + } + + for condition in push_rule.conditions.iter() { + match self.match_condition(condition, user_id, display_name) { + Ok(true) => {} + Ok(false) => continue 'outer, + Err(err) => { + warn!("Condition match failed {err}"); + continue 'outer; + } + } + } + + let actions = push_rule + .actions + .iter() + // Filter out "dont_notify" actions, as we don't store them. + .filter(|a| **a != Action::DontNotify) + .cloned() + .collect(); + + return actions; + } + + Vec::new() + } + + /// Check if the given condition matches. + fn matches( + &self, + condition: Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> bool { + match self.match_condition(&condition, user_id, display_name) { + Ok(true) => true, + Ok(false) => false, + Err(err) => { + warn!("Condition match failed {err}"); + false + } + } + } +} + +impl PushRuleEvaluator { + /// Match a given `Condition` for a push rule. + pub fn match_condition( + &self, + condition: &Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Result { + let known_condition = match condition { + Condition::Known(known) => known, + Condition::Unknown(_) => { + return Ok(false); + } + }; + + let result = match known_condition { + KnownCondition::EventMatch(event_match) => { + self.match_event_match(event_match, user_id)? + } + KnownCondition::ContainsDisplayName => { + if let Some(dn) = display_name { + if !dn.is_empty() { + get_glob_matcher(dn, GlobMatchType::Word)?.is_match(&self.body)? + } else { + // We specifically ignore empty display names, as otherwise + // they would always match. + false + } + } else { + false + } + } + KnownCondition::RoomMemberCount { is } => { + if let Some(is) = is { + self.match_member_count(is)? + } else { + false + } + } + KnownCondition::SenderNotificationPermission { key } => { + if let Some(sender_power_level) = &self.sender_power_level { + let required_level = self + .notification_power_levels + .get(key.as_ref()) + .copied() + .unwrap_or(50); + + *sender_power_level >= required_level + } else { + false + } + } + KnownCondition::RelationMatch { + rel_type, + event_type_pattern, + sender, + sender_type, + } => { + self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? + } + }; + + Ok(result) + } + + /// Evaluates a relation condition. + fn match_relations( + &self, + rel_type: &str, + sender: &Option>, + sender_type: &Option>, + user_id: Option<&str>, + event_type_pattern: &Option>, + ) -> Result { + // First check if relation matching is enabled... + if !self.relation_match_enabled { + return Ok(false); + } + + // ... and if there are any relations to match against. + let relations = if let Some(relations) = self.relations.get(rel_type) { + relations + } else { + return Ok(false); + }; + + // Extract the sender pattern from the condition + let sender_pattern = if let Some(sender) = sender { + Some(sender.as_ref()) + } else if let Some(sender_type) = sender_type { + if sender_type == "user_id" { + if let Some(user_id) = user_id { + Some(user_id) + } else { + return Ok(false); + } + } else { + warn!("Unrecognized sender_type: {sender_type}"); + return Ok(false); + } + } else { + None + }; + + let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { + Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) + } else { + None + }; + + let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { + Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) + } else { + None + }; + + for (relation_sender, event_type) in relations { + if let Some(pattern) = &mut sender_compiled_pattern { + if !pattern.is_match(relation_sender)? { + continue; + } + } + + if let Some(pattern) = &mut type_compiled_pattern { + if !pattern.is_match(event_type)? { + continue; + } + } + + return Ok(true); + } + + Ok(false) + } + + /// Evaluates a `event_match` condition. + fn match_event_match( + &self, + event_match: &EventMatchCondition, + user_id: Option<&str>, + ) -> Result { + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if event_match.key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + + /// Match the member count against an 'is' condition + /// The `is` condition can be things like '>2', '==3' or even just '4'. + fn match_member_count(&self, is: &str) -> Result { + let captures = INEQUALITY_EXPR.captures(is).context("bad 'is' clause")?; + let ineq = captures.get(1).map_or("==", |m| m.as_str()); + let rhs: u64 = captures + .get(2) + .context("missing number")? + .as_str() + .parse()?; + + let matches = match ineq { + "" | "==" => self.room_member_count == rhs, + "<" => self.room_member_count < rhs, + ">" => self.room_member_count > rhs, + ">=" => self.room_member_count >= rhs, + "<=" => self.room_member_count <= rhs, + _ => false, + }; + + Ok(matches) + } +} + +#[test] +fn push_rule_evaluator() { + let mut flattened_keys = BTreeMap::new(); + flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); + + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + assert_eq!(result.len(), 3); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index de6764e7c5..30fffc31ad 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -42,7 +42,6 @@ //! //! The set of "base rules" are the list of rules that every user has by default. A //! user can modify their copy of the push rules in one of three ways: -//! //! 1. Adding a new push rule of a certain kind //! 2. Changing the actions of a base rule //! 3. Enabling/disabling a base rule. @@ -58,12 +57,16 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use anyhow::{Context, Error}; use log::warn; use pyo3::prelude::*; -use pythonize::pythonize; +use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; use serde_json::Value; +use self::evaluator::PushRuleEvaluator; + mod base_rules; +pub mod evaluator; +pub mod utils; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -71,6 +74,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?; m.add_submodule(child_module)?; @@ -274,6 +278,8 @@ pub enum KnownCondition { #[serde(rename = "org.matrix.msc3772.relation_match")] RelationMatch { rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none", rename = "type")] + event_type_pattern: Option>, #[serde(skip_serializing_if = "Option::is_none")] sender: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -287,20 +293,26 @@ impl IntoPy for Condition { } } +impl<'source> FromPyObject<'source> for Condition { + fn extract(ob: &'source PyAny) -> PyResult { + Ok(depythonize(ob)?) + } +} + /// The body of a [`Condition::EventMatch`] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventMatchCondition { - key: Cow<'static, str>, + pub key: Cow<'static, str>, #[serde(skip_serializing_if = "Option::is_none")] - pattern: Option>, + pub pattern: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pattern_type: Option>, + pub pattern_type: Option>, } /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] -struct PushRules { +pub struct PushRules { /// Custom push rules that override a base rule. overridden_base_rules: HashMap, PushRule>, @@ -319,7 +331,7 @@ struct PushRules { #[pymethods] impl PushRules { #[new] - fn new(rules: Vec) -> PushRules { + pub fn new(rules: Vec) -> PushRules { let mut push_rules: PushRules = Default::default(); for rule in rules { @@ -396,7 +408,7 @@ pub struct FilteredPushRules { #[pymethods] impl FilteredPushRules { #[new] - fn py_new( + pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, msc3786_enabled: bool, diff --git a/rust/src/push/utils.rs b/rust/src/push/utils.rs new file mode 100644 index 0000000000..8759340473 --- /dev/null +++ b/rust/src/push/utils.rs @@ -0,0 +1,215 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::bail; +use anyhow::Context; +use anyhow::Error; +use lazy_static::lazy_static; +use regex; +use regex::Regex; +use regex::RegexBuilder; + +lazy_static! { + /// Matches runs of non-wildcard characters followed by wildcard characters. + static ref WILDCARD_RUN: Regex = Regex::new(r"([^\?\*]*)([\?\*]*)").expect("valid regex"); +} + +/// Extract the localpart from a Matrix style ID +pub(crate) fn get_localpart_from_id(id: &str) -> Result<&str, Error> { + let (localpart, _) = id + .split_once(':') + .with_context(|| format!("ID does not contain colon: {id}"))?; + + // We need to strip off the first character, which is the ID type. + if localpart.is_empty() { + bail!("Invalid ID {id}"); + } + + Ok(&localpart[1..]) +} + +/// Used by `glob_to_regex` to specify what to match the regex against. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GlobMatchType { + /// The generated regex will match against the entire input. + Whole, + /// The generated regex will match against words. + Word, +} + +/// Convert a "glob" style expression to a regex, anchoring either to the entire +/// input or to individual words. +pub fn glob_to_regex(glob: &str, match_type: GlobMatchType) -> Result { + let mut chunks = Vec::new(); + + // Patterns with wildcards must be simplified to avoid performance cliffs + // - The glob `?**?**?` is equivalent to the glob `???*` + // - The glob `???*` is equivalent to the regex `.{3,}` + for captures in WILDCARD_RUN.captures_iter(glob) { + if let Some(chunk) = captures.get(1) { + chunks.push(regex::escape(chunk.as_str())); + } + + if let Some(wildcards) = captures.get(2) { + if wildcards.as_str() == "" { + continue; + } + + let question_marks = wildcards.as_str().chars().filter(|c| *c == '?').count(); + + if wildcards.as_str().contains('*') { + chunks.push(format!(".{{{question_marks},}}")); + } else { + chunks.push(format!(".{{{question_marks}}}")); + } + } + } + + let joined = chunks.join(""); + + let regex_str = match match_type { + GlobMatchType::Whole => format!(r"\A{joined}\z"), + + // `^|\W` and `\W|$` handle the case where `pattern` starts or ends with a non-word + // character. + GlobMatchType::Word => format!(r"(?:^|\b|\W){joined}(?:\b|\W|$)"), + }; + + Ok(RegexBuilder::new(®ex_str) + .case_insensitive(true) + .build()?) +} + +/// Compiles the glob into a `Matcher`. +pub fn get_glob_matcher(glob: &str, match_type: GlobMatchType) -> Result { + // There are a number of shortcuts we can make if the glob doesn't contain a + // wild card. + let matcher = if glob.contains(['*', '?']) { + let regex = glob_to_regex(glob, match_type)?; + Matcher::Regex(regex) + } else if match_type == GlobMatchType::Whole { + // If there aren't any wildcards and we're matching the whole thing, + // then we simply can do a case-insensitive string match. + Matcher::Whole(glob.to_lowercase()) + } else { + // Otherwise, if we're matching against words then can first check + // if the haystack contains the glob at all. + Matcher::Word { + word: glob.to_lowercase(), + regex: None, + } + }; + + Ok(matcher) +} + +/// Matches against a glob +pub enum Matcher { + /// Plain regex matching. + Regex(Regex), + + /// Case-insensitive equality. + Whole(String), + + /// Word matching. `regex` is a cache of calling [`glob_to_regex`] on word. + Word { word: String, regex: Option }, +} + +impl Matcher { + /// Checks if the glob matches the given haystack. + pub fn is_match(&mut self, haystack: &str) -> Result { + // We want to to do case-insensitive matching, so we convert to + // lowercase first. + let haystack = haystack.to_lowercase(); + + match self { + Matcher::Regex(regex) => Ok(regex.is_match(&haystack)), + Matcher::Whole(whole) => Ok(whole == &haystack), + Matcher::Word { word, regex } => { + // If we're looking for a literal word, then we first check if + // the haystack contains the word as a substring. + if !haystack.contains(&*word) { + return Ok(false); + } + + // If it does contain the word as a substring, then we need to + // check if it is an actual word by testing it against the regex. + let regex = if let Some(regex) = regex { + regex + } else { + let compiled_regex = glob_to_regex(word, GlobMatchType::Word)?; + regex.insert(compiled_regex) + }; + + Ok(regex.is_match(&haystack)) + } + } + } +} + +#[test] +fn test_get_domain_from_id() { + get_localpart_from_id("").unwrap_err(); + get_localpart_from_id(":").unwrap_err(); + get_localpart_from_id(":asd").unwrap_err(); + get_localpart_from_id("::as::asad").unwrap_err(); + + assert_eq!(get_localpart_from_id("@test:foo").unwrap(), "test"); + assert_eq!(get_localpart_from_id("@:").unwrap(), ""); + assert_eq!(get_localpart_from_id("@test:foo:907").unwrap(), "test"); +} + +#[test] +fn tset_glob() -> Result<(), Error> { + assert_eq!( + glob_to_regex("simple", GlobMatchType::Whole)?.as_str(), + r"\Asimple\z" + ); + assert_eq!( + glob_to_regex("simple*", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{0,}\z" + ); + assert_eq!( + glob_to_regex("simple?", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{1}\z" + ); + assert_eq!( + glob_to_regex("simple?*?*", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{2,}\z" + ); + assert_eq!( + glob_to_regex("simple???", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{3}\z" + ); + + assert_eq!( + glob_to_regex("escape.", GlobMatchType::Whole)?.as_str(), + r"\Aescape\.\z" + ); + + assert!(glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simple")); + assert!(!glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple?", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simple")); + + assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("some simple.")); + assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("simple")); + assert!(!glob_to_regex("simple", GlobMatchType::Word)?.is_match("simples")); + + assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("Some @user:foo test")); + assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("@user:foo")); + + Ok(()) +} diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 93c4e69d42..fffb8419c6 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -1,4 +1,4 @@ -from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from synapse.types import JsonDict @@ -35,3 +35,20 @@ class FilteredPushRules: def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... + +class PushRuleEvaluator: + def __init__( + self, + flattened_keys: Mapping[str, str], + room_member_count: int, + sender_power_level: Optional[int], + notification_power_levels: Mapping[str, int], + relations: Mapping[str, Set[Tuple[str, str]]], + relation_match_enabled: bool, + ): ... + def run( + self, + push_rules: FilteredPushRules, + user_id: Optional[str], + display_name: Optional[str], + ) -> Collection[dict]: ... diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 32313e3bcf..60f3129005 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -17,6 +17,7 @@ import itertools import logging from typing import ( TYPE_CHECKING, + Any, Collection, Dict, Iterable, @@ -37,13 +38,11 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule +from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state -from .push_rule_evaluator import PushRuleEvaluatorForEvent - if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,11 +289,11 @@ class BulkPushRuleEvaluator: if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id - evaluator = PushRuleEvaluatorForEvent( - event, + evaluator = PushRuleEvaluator( + _flatten_dict(event), room_member_count, sender_power_level, - power_levels, + power_levels.get("notifications", {}), relations, self._relations_match_enabled, ) @@ -338,17 +337,10 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - for rule, enabled in rules.rules(): - if not enabled: - continue - - matches = evaluator.check_conditions(rule.conditions, uid, display_name) - if matches: - actions = [x for x in rule.actions if x != "dont_notify"] - if actions and "notify" in actions: - # Push rules say we should notify the user of this event - actions_by_user[uid] = actions - break + actions = evaluator.run(rules, uid, display_name) + if "notify" in actions: + # Push rules say we should notify the user of this event + actions_by_user[uid] = actions # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist @@ -365,3 +357,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]] Rule = Dict[str, dict] RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] + + +def _flatten_dict( + d: Union[EventBase, Mapping[str, Any]], + prefix: Optional[List[str]] = None, + result: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + if prefix is None: + prefix = [] + if result is None: + result = {} + for key, value in d.items(): + if isinstance(value, str): + result[".".join(prefix + [key])] = value.lower() + elif isinstance(value, Mapping): + _flatten_dict(value, prefix=(prefix + [key]), result=result) + + return result diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e96fb45e9f..b048b03a74 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union from prometheus_client import Counter @@ -28,7 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.storage.databases.main.event_push_actions import HttpPushAction -from . import push_rule_evaluator, push_tools +from . import push_tools if TYPE_CHECKING: from synapse.server import HomeServer @@ -56,6 +56,39 @@ http_badges_failed_counter = Counter( ) +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: + """ + Converts a list of actions into a `tweaks` dict (which can then be passed to + the push gateway). + + This function ignores all actions other than `set_tweak` actions, and treats + absent `value`s as `True`, which agrees with the only spec-defined treatment + of absent `value`s (namely, for `highlight` tweaks). + + Args: + actions: list of actions + e.g. [ + {"set_tweak": "a", "value": "AAA"}, + {"set_tweak": "b", "value": "BBB"}, + {"set_tweak": "highlight"}, + "notify" + ] + + Returns: + dictionary of tweaks for those actions + e.g. {"a": "AAA", "b": "BBB", "highlight": True} + """ + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if "set_tweak" in a: + # value is allowed to be absent in which case the value assumed + # should be True. + tweaks[a["set_tweak"]] = a.get("value", True) + return tweaks + + class HttpPusher(Pusher): INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes MAX_BACKOFF_SEC = 60 * 60 @@ -281,7 +314,7 @@ class HttpPusher(Pusher): if "notify" not in push_action.actions: return True - tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) + tweaks = tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( self.hs.get_datastores().main, self.user_id, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py deleted file mode 100644 index f8176c5a42..0000000000 --- a/synapse/push/push_rule_evaluator.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -from typing import ( - Any, - Dict, - List, - Mapping, - Optional, - Pattern, - Sequence, - Set, - Tuple, - Union, -) - -from matrix_common.regex import glob_to_regex, to_word_pattern - -from synapse.events import EventBase -from synapse.types import UserID -from synapse.util.caches.lrucache import LruCache - -logger = logging.getLogger(__name__) - - -GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") -IS_GLOB = re.compile(r"[\?\*\[\]]") -INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") - - -def _room_member_count(condition: Mapping[str, Any], room_member_count: int) -> bool: - return _test_ineq_condition(condition, room_member_count) - - -def _sender_notification_permission( - condition: Mapping[str, Any], - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], -) -> bool: - if sender_power_level is None: - return False - - notif_level_key = condition.get("key") - if notif_level_key is None: - return False - - notif_levels = power_levels.get("notifications", {}) - assert isinstance(notif_levels, dict) - room_notif_level = notif_levels.get(notif_level_key, 50) - - return sender_power_level >= room_notif_level - - -def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool: - if "is" not in condition: - return False - m = INEQUALITY_EXPR.match(condition["is"]) - if not m: - return False - ineq = m.group(1) - rhs = m.group(2) - if not rhs.isdigit(): - return False - rhs_int = int(rhs) - - if ineq == "" or ineq == "==": - return number == rhs_int - elif ineq == "<": - return number < rhs_int - elif ineq == ">": - return number > rhs_int - elif ineq == ">=": - return number >= rhs_int - elif ineq == "<=": - return number <= rhs_int - else: - return False - - -def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: - """ - Converts a list of actions into a `tweaks` dict (which can then be passed to - the push gateway). - - This function ignores all actions other than `set_tweak` actions, and treats - absent `value`s as `True`, which agrees with the only spec-defined treatment - of absent `value`s (namely, for `highlight` tweaks). - - Args: - actions: list of actions - e.g. [ - {"set_tweak": "a", "value": "AAA"}, - {"set_tweak": "b", "value": "BBB"}, - {"set_tweak": "highlight"}, - "notify" - ] - - Returns: - dictionary of tweaks for those actions - e.g. {"a": "AAA", "b": "BBB", "highlight": True} - """ - tweaks = {} - for a in actions: - if not isinstance(a, dict): - continue - if "set_tweak" in a: - # value is allowed to be absent in which case the value assumed - # should be True. - tweaks[a["set_tweak"]] = a.get("value", True) - return tweaks - - -class PushRuleEvaluatorForEvent: - def __init__( - self, - event: EventBase, - room_member_count: int, - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], - relations: Dict[str, Set[Tuple[str, str]]], - relations_match_enabled: bool, - ): - self._event = event - self._room_member_count = room_member_count - self._sender_power_level = sender_power_level - self._power_levels = power_levels - self._relations = relations - self._relations_match_enabled = relations_match_enabled - - # Maps strings of e.g. 'content.body' -> event["content"]["body"] - self._value_cache = _flatten_dict(event) - - # Maps cache keys to final values. - self._condition_cache: Dict[str, bool] = {} - - def check_conditions( - self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's conditions/user ID/display name match the event. - - Args: - conditions: The user's conditions to match. - uid: The user's MXID. - display_name: The display name. - - Returns: - True if all conditions match the event, False otherwise. - """ - for cond in conditions: - _cache_key = cond.get("_cache_key", None) - if _cache_key: - res = self._condition_cache.get(_cache_key, None) - if res is False: - return False - elif res is True: - continue - - res = self.matches(cond, uid, display_name) - if _cache_key: - self._condition_cache[_cache_key] = bool(res) - - if not res: - return False - - return True - - def matches( - self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's condition/user ID/display name match the event. - - Args: - condition: The user's condition to match. - uid: The user's MXID. - display_name: The display name, or None if there is not one. - - Returns: - True if the condition matches the event, False otherwise. - """ - if condition["kind"] == "event_match": - return self._event_match(condition, user_id) - elif condition["kind"] == "contains_display_name": - return self._contains_display_name(display_name) - elif condition["kind"] == "room_member_count": - return _room_member_count(condition, self._room_member_count) - elif condition["kind"] == "sender_notification_permission": - return _sender_notification_permission( - condition, self._sender_power_level, self._power_levels - ) - elif ( - condition["kind"] == "org.matrix.msc3772.relation_match" - and self._relations_match_enabled - ): - return self._relation_match(condition, user_id) - else: - # XXX This looks incorrect -- we have reached an unknown condition - # kind and are unconditionally returning that it matches. Note - # that it seems possible to provide a condition to the /pushrules - # endpoint with an unknown kind, see _rule_tuple_from_request_object. - return True - - def _event_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - pattern = condition.get("pattern", None) - - if not pattern: - pattern_type = condition.get("pattern_type", None) - if pattern_type == "user_id": - pattern = user_id - elif pattern_type == "user_localpart": - pattern = UserID.from_string(user_id).localpart - - if not pattern: - logger.warning("event_match condition with no pattern") - return False - - # XXX: optimisation: cache our pattern regexps - if condition["key"] == "content.body": - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - return _glob_matches(pattern, body, word_boundary=True) - else: - haystack = self._value_cache.get(condition["key"], None) - if haystack is None: - return False - - return _glob_matches(pattern, haystack) - - def _contains_display_name(self, display_name: Optional[str]) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - display_name: The display name, or None if there is not one. - - Returns: - True if the display name is found in the event body, False otherwise. - """ - if not display_name: - return False - - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - # Similar to _glob_matches, but do not treat display_name as a glob. - r = regex_cache.get((display_name, False, True), None) - if not r: - r1 = re.escape(display_name) - r1 = to_word_pattern(r1) - r = re.compile(r1, flags=re.IGNORECASE) - regex_cache[(display_name, False, True)] = r - - return bool(r.search(body)) - - def _relation_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "relation_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - rel_type = condition.get("rel_type") - if not rel_type: - logger.warning("relation_match condition missing rel_type") - return False - - sender_pattern = condition.get("sender") - if sender_pattern is None: - sender_type = condition.get("sender_type") - if sender_type == "user_id": - sender_pattern = user_id - type_pattern = condition.get("type") - - # If any other relations matches, return True. - for sender, event_type in self._relations.get(rel_type, ()): - if sender_pattern and not _glob_matches(sender_pattern, sender): - continue - if type_pattern and not _glob_matches(type_pattern, event_type): - continue - # All values must have matched. - return True - - # No relations matched. - return False - - -# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( - 50000, "regex_push_cache" -) - - -def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: - """Tests if value matches glob. - - Args: - glob - value: String to test against glob. - word_boundary: Whether to match against word boundaries or entire - string. Defaults to False. - """ - - try: - r = regex_cache.get((glob, True, word_boundary), None) - if not r: - r = glob_to_regex(glob, word_boundary=word_boundary) - regex_cache[(glob, True, word_boundary)] = r - return bool(r.search(value)) - except re.error: - logger.warning("Failed to parse glob to regex: %r", glob) - return False - - -def _flatten_dict( - d: Union[EventBase, Mapping[str, Any]], - prefix: Optional[List[str]] = None, - result: Optional[Dict[str, str]] = None, -) -> Dict[str, str]: - if prefix is None: - prefix = [] - if result is None: - result = {} - for key, value in d.items(): - if isinstance(value, str): - result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, Mapping): - _flatten_dict(value, prefix=(prefix + [key]), result=result) - - return result diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 718f489577..b8308cbc05 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -23,11 +23,12 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService from synapse.events import FrozenEvent -from synapse.push import push_rule_evaluator -from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.push.bulk_push_rule_evaluator import _flatten_dict +from synapse.push.httppusher import tweaks_for_actions from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex +from synapse.synapse_rust.push import PushRuleEvaluator from synapse.types import JsonDict from synapse.util import Clock @@ -41,7 +42,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): content: JsonDict, relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, relations_match_enabled: bool = False, - ) -> PushRuleEvaluatorForEvent: + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -56,12 +57,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count = 0 sender_power_level = 0 power_levels: Dict[str, Union[int, Dict[str, int]]] = {} - return PushRuleEvaluatorForEvent( - event, + return PushRuleEvaluator( + _flatten_dict(event), room_member_count, sender_power_level, - power_levels, - relations or set(), + power_levels.get("notifications", {}), + relations or {}, relations_match_enabled, ) @@ -293,7 +294,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ] self.assertEqual( - push_rule_evaluator.tweaks_for_actions(actions), + tweaks_for_actions(actions), {"sound": "default", "highlight": True}, ) @@ -304,9 +305,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): evaluator = self._get_evaluator( {}, {"m.annotation": {("@user:test", "m.reaction")}} ) - condition = {"kind": "relation_match"} - # Oddly, an unknown condition always matches. - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) # A push rule evaluator with the experimental rule enabled. evaluator = self._get_evaluator( -- cgit 1.5.1 From 285b9e9b6c3558718e7d4f513062e277948ac35d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Sep 2022 14:27:00 +0100 Subject: Speed up calculating push actions in large rooms (#13973) We move the expensive check of visibility to after calculating push actions, avoiding the expensive check for users who won't get pushed anyway. I think this should have a big impact on rooms with large numbers of local users that have pushed disabled. --- changelog.d/13973.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 25 ++++++---- tests/push/test_push_rule_evaluator.py | 82 +++++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 changelog.d/13973.misc (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13973.misc b/changelog.d/13973.misc new file mode 100644 index 0000000000..58150a2b35 --- /dev/null +++ b/changelog.d/13973.misc @@ -0,0 +1 @@ +Speed up calculating push actions in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 60f3129005..7bfe380543 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -303,20 +303,10 @@ class BulkPushRuleEvaluator: event.room_id, users ) - # This is a check for the case where user joins a room without being - # allowed to see history, and then the server receives a delayed event - # from before the user joined, which they should not be pushed for - uids_with_visibility = await filter_event_for_clients_with_state( - self.store, users, event, context - ) - for uid, rules in rules_by_user.items(): if event.sender == uid: continue - if uid not in uids_with_visibility: - continue - display_name = None profile = profiles.get(uid) if profile: @@ -342,6 +332,21 @@ class BulkPushRuleEvaluator: # Push rules say we should notify the user of this event actions_by_user[uid] = actions + # This is a check for the case where user joins a room without being + # allowed to see history, and then the server receives a delayed event + # from before the user joined, which they should not be pushed for + # + # We do this *after* calculating the push actions as a) its unlikely + # that we'll filter anyone out and b) for large rooms its likely that + # most users will have push disabled and so the set of users to check is + # much smaller. + uids_with_visibility = await filter_event_for_clients_with_state( + self.store, actions_by_user.keys(), event, context + ) + + for user_id in set(actions_by_user).difference(uids_with_visibility): + actions_by_user.pop(user_id, None) + # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index b8308cbc05..8804f0e0d3 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -19,17 +19,18 @@ import frozendict from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService from synapse.events import FrozenEvent from synapse.push.bulk_push_rule_evaluator import _flatten_dict from synapse.push.httppusher import tweaks_for_actions +from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import PushRuleEvaluator -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest @@ -437,3 +438,80 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): ) self.assertEqual(len(users_with_push_actions), 0) + + +class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.main_store = homeserver.get_datastores().main + + self.user_id1 = self.register_user("user1", "password") + self.tok1 = self.login(self.user_id1, "password") + self.user_id2 = self.register_user("user2", "password") + self.tok2 = self.login(self.user_id2, "password") + + self.room_id = self.helper.create_room_as(tok=self.tok1) + + # We want to test history visibility works correctly. + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + {"history_visibility": HistoryVisibility.JOINED}, + tok=self.tok1, + ) + + def get_notif_count(self, user_id: str) -> int: + return self.get_success( + self.main_store.db_pool.simple_select_one_onecol( + table="event_push_actions", + keyvalues={"user_id": user_id}, + retcol="COALESCE(SUM(notif), 0)", + desc="get_staging_notif_count", + ) + ) + + def test_plain_message(self) -> None: + """Test that sending a normal message in a room will trigger a + notification + """ + + # Have user2 join the room and cle + self.helper.join(self.room_id, self.user_id2, tok=self.tok2) + + # They start off with no notifications, but get them when messages are + # sent. + self.assertEqual(self.get_notif_count(self.user_id2), 0) + + user1 = UserID.from_string(self.user_id1) + self.create_and_send_event(self.room_id, user1) + + self.assertEqual(self.get_notif_count(self.user_id2), 1) + + def test_delayed_message(self) -> None: + """Test that a delayed message that was from before a user joined + doesn't cause a notification for the joined user. + """ + user1 = UserID.from_string(self.user_id1) + + # Send a message before user2 joins + event_id1 = self.create_and_send_event(self.room_id, user1) + + # Have user2 join the room + self.helper.join(self.room_id, self.user_id2, tok=self.tok2) + + # They start off with no notifications + self.assertEqual(self.get_notif_count(self.user_id2), 0) + + # Send another message that references the event before the join to + # simulate a "delayed" event + self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1]) + + # user2 should not be notified about it, because they can't see it. + self.assertEqual(self.get_notif_count(self.user_id2), 0) -- cgit 1.5.1 From 535f8c8f7d64d4058500a5988278fd3026645164 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Sep 2022 17:40:33 +0100 Subject: Skip filtering during push if there are no push actions (#13992) --- changelog.d/13992.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 5 +++++ synapse/visibility.py | 4 ++++ tests/rest/client/test_rooms.py | 4 ++-- 4 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13992.misc (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13992.misc b/changelog.d/13992.misc new file mode 100644 index 0000000000..58150a2b35 --- /dev/null +++ b/changelog.d/13992.misc @@ -0,0 +1 @@ +Speed up calculating push actions in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7bfe380543..4270438918 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -332,6 +332,11 @@ class BulkPushRuleEvaluator: # Push rules say we should notify the user of this event actions_by_user[uid] = actions + # If there aren't any actions then we can skip the rest of the + # processing. + if not actions_by_user: + return + # This is a check for the case where user joins a room without being # allowed to see history, and then the server receives a delayed event # from before the user joined, which they should not be pushed for diff --git a/synapse/visibility.py b/synapse/visibility.py index c810a05907..c4048d2477 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -162,6 +162,10 @@ async def filter_event_for_clients_with_state( if event.internal_metadata.is_soft_failed(): return [] + # Fast path if we don't have any user IDs to check. + if not user_ids: + return () + # Make a set for all user IDs that haven't been filtered out by a check. allowed_user_ids = set(user_ids) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index e281aef779..7f8cf4fab0 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(34, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(38, channel.resource_usage.db_txn_count) + self.assertEqual(37, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From b4ec4f5e71a87d5bdc840a4220dfd9a34c54c847 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 09:47:04 -0400 Subject: Track notification counts per thread (implement MSC3773). (#13776) When retrieving counts of notifications segment the results based on the thread ID, but choose whether to return them as individual threads or as a single summed field by letting the client opt-in via a sync flag. The summarization code is also updated to be per thread, instead of per room. --- changelog.d/13776.feature | 1 + synapse/api/constants.py | 3 + synapse/api/filtering.py | 10 ++ synapse/config/experimental.py | 2 + synapse/handlers/sync.py | 40 ++++- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_tools.py | 9 +- synapse/rest/client/sync.py | 4 + synapse/rest/client/versions.py | 3 +- synapse/storage/database.py | 2 +- .../storage/databases/main/event_push_actions.py | 188 +++++++++++++-------- synapse/storage/schema/__init__.py | 6 +- .../delta/73/06thread_notifications_backfill.sql | 29 ++++ .../07thread_notifications_not_null.sql.postgres | 19 +++ .../73/07thread_notifications_not_null.sql.sqlite | 101 +++++++++++ tests/replication/slave/storage/test_events.py | 17 +- tests/storage/test_event_push_actions.py | 169 +++++++++++++++++- 17 files changed, 514 insertions(+), 93 deletions(-) create mode 100644 changelog.d/13776.feature create mode 100644 synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13776.feature b/changelog.d/13776.feature new file mode 100644 index 0000000000..22bce125ce --- /dev/null +++ b/changelog.d/13776.feature @@ -0,0 +1 @@ +Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c031903b1a..44c5ffc6a5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -31,6 +31,9 @@ MAX_ALIAS_LENGTH = 255 # the maximum length for a user id is 255 characters MAX_USERID_LENGTH = 255 +# Constant value used for the pseudo-thread which is the main timeline. +MAIN_TIMELINE: Final = "main" + class Membership: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f7f46f8d80..c6e44dcf82 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"}, # Include or exclude events with the provided labels. # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, @@ -240,6 +241,9 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members + def unread_thread_notifications(self) -> bool: + return self._room_timeline_filter.unread_thread_notifications + async def filter_presence( self, events: Iterable[UserPresenceState] ) -> List[UserPresenceState]: @@ -304,6 +308,12 @@ class Filter: self.include_redundant_members = filter_json.get( "include_redundant_members", False ) + if hs.config.experimental.msc3773_enabled: + self.unread_thread_notifications: bool = filter_json.get( + "org.matrix.msc3773.unread_thread_notifications", False + ) + else: + self.unread_thread_notifications = False self.types = filter_json.get("types", None) self.not_types = filter_json.get("not_types", []) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 83695f24d9..6503ce6e34 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -99,6 +99,8 @@ class ExperimentalConfig(Config): self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) + # MSC3773: Thread notifications + self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) # MSC3715: dir param on /relations. self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4abb9b6127..329e89c604 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -40,7 +40,7 @@ from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -128,6 +128,7 @@ class JoinedSyncResult: ephemeral: List[JsonDict] account_data: List[JsonDict] unread_notifications: JsonDict + unread_thread_notifications: JsonDict summary: Optional[JsonDict] unread_count: int @@ -278,6 +279,8 @@ class SyncHandler: self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + self._msc3773_enabled = hs.config.experimental.msc3773_enabled + async def wait_for_sync_for_user( self, requester: Requester, @@ -1288,7 +1291,7 @@ class SyncHandler: async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> NotifCounts: + ) -> RoomNotifCounts: with Measure(self.clock, "unread_notifs_for_room_id"): return await self.store.get_unread_event_push_actions_by_room_for_user( @@ -2353,6 +2356,7 @@ class SyncHandler: ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + unread_thread_notifications={}, summary=summary, unread_count=0, ) @@ -2360,10 +2364,36 @@ class SyncHandler: if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - unread_notifications["notification_count"] = notifs.notify_count - unread_notifications["highlight_count"] = notifs.highlight_count + # Notifications for the main timeline. + notify_count = notifs.main_timeline.notify_count + highlight_count = notifs.main_timeline.highlight_count + unread_count = notifs.main_timeline.unread_count - room_sync.unread_count = notifs.unread_count + # Check the sync configuration. + if ( + self._msc3773_enabled + and sync_config.filter_collection.unread_thread_notifications() + ): + # And add info for each thread. + room_sync.unread_thread_notifications = { + thread_id: { + "notification_count": thread_notifs.notify_count, + "highlight_count": thread_notifs.highlight_count, + } + for thread_id, thread_notifs in notifs.threads.items() + if thread_id is not None + } + + else: + # Combine the unread counts for all threads and main timeline. + for thread_notifs in notifs.threads.values(): + notify_count += thread_notifs.notify_count + highlight_count += thread_notifs.highlight_count + unread_count += thread_notifs.unread_count + + unread_notifications["notification_count"] = notify_count + unread_notifications["highlight_count"] = highlight_count + room_sync.unread_count = unread_count sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4270438918..61d952742d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -31,7 +31,7 @@ from typing import ( from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -280,7 +280,7 @@ class BulkPushRuleEvaluator: # If the event does not have a relation, then cannot have any mutual # relations or thread ID. relations = {} - thread_id = "main" + thread_id = MAIN_TIMELINE if relation: relations = await self._get_mutual_relations( relation.parent_id, diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 658bf373b7..edeba27a45 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,7 +39,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - await concurrently_execute(get_room_unread_count, joins, 10) for notifs in room_notifs: - if notifs.notify_count == 0: + # Combine the counts from all the threads. + notify_count = notifs.main_timeline.notify_count + sum( + n.notify_count for n in notifs.threads.values() + ) + + if notify_count == 0: continue if group_by_room: @@ -47,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge += 1 else: # increment the badge count by the number of unread messages in the room - badge += notifs.notify_count + badge += notify_count return badge diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index c2989765ce..f1c23d68e5 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -509,6 +509,10 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + if room.unread_thread_notifications: + result[ + "org.matrix.msc3773.unread_thread_notifications" + ] = room.unread_thread_notifications result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c95b0d6f19..280d306483 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,8 +103,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above - # Support for thread read receipts. + # Support for thread read receipts & notification counts. "org.matrix.msc3771": self.config.experimental.msc3771_enabled, + "org.matrix.msc3773": self.config.experimental.msc3773_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b4469eb964..7bb21f8f81 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -94,7 +94,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", - "event_push_summary": "event_push_summary_unique_index", + "event_push_summary": "event_push_summary_unique_index2", "receipts_linearized": "receipts_linearized_unique_index", "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index cdc9ee5a37..3210e9cca1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -88,7 +88,7 @@ from typing import ( import attr -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -157,7 +157,7 @@ class UserPushAction(EmailPushAction): @attr.s(slots=True, auto_attribs=True) class NotifCounts: """ - The per-user, per-room count of notifications. Used by sync and push. + The per-user, per-room, per-thread count of notifications. Used by sync and push. """ notify_count: int = 0 @@ -165,6 +165,21 @@ class NotifCounts: highlight_count: int = 0 +@attr.s(slots=True, auto_attribs=True) +class RoomNotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + main_timeline: NotifCounts + # Map of thread ID to the notification counts. + threads: Dict[str, NotifCounts] + + def __len__(self) -> int: + # To properly account for the amount of space in any caches. + return len(self.threads) + 1 + + def _serialize_action( actions: Collection[Union[Mapping, str]], is_highlight: bool ) -> str: @@ -338,12 +353,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result - @cached(tree=True, max_entries=5000) + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after their latest read receipt. @@ -356,8 +371,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to retrieve the counts for. Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", @@ -371,7 +387,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: # Get the stream ordering of the user's latest receipt in the room. result = self.get_last_unthreaded_receipt_for_user_txn( txn, @@ -406,7 +422,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas room_id: str, user_id: str, receipt_stream_ordering: int, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. @@ -418,12 +434,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt in the room. If there are no receipts, the stream ordering of the user's join event. - Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + Returns: + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ - counts = NotifCounts() + main_counts = NotifCounts() + thread_counts: Dict[str, NotifCounts] = {} + + def _get_thread(thread_id: str) -> NotifCounts: + if thread_id == MAIN_TIMELINE: + return main_counts + return thread_counts.setdefault(thread_id, NotifCounts()) # First we pull the counts from the summary table. # @@ -440,52 +463,61 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # receipt). txn.execute( """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary WHERE room_id = ? AND user_id = ? AND ( (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) OR last_receipt_stream_ordering = ? - ) + ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0) """, (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), ) - row = txn.fetchone() - - summary_stream_ordering = 0 - if row: - summary_stream_ordering = row[0] - counts.notify_count += row[1] - counts.unread_count += row[2] + max_summary_stream_ordering = 0 + for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count + + # Summaries will only be used if they have not been invalidated by + # a recent receipt; track the latest stream ordering or a valid summary. + # + # Note that since there's only one read receipt in the room per user, + # valid summaries are contiguous. + max_summary_stream_ordering = max( + summary_stream_ordering, max_summary_stream_ordering + ) # Next we need to count highlights, which aren't summarised sql = """ - SELECT COUNT(*) FROM event_push_actions + SELECT COUNT(*), thread_id FROM event_push_actions WHERE user_id = ? AND room_id = ? AND stream_ordering > ? AND highlight = 1 + GROUP BY thread_id """ txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) - row = txn.fetchone() - if row: - counts.highlight_count += row[0] + for highlight_count, thread_id in txn: + _get_thread(thread_id).highlight_count += highlight_count # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. start_unread_stream_ordering = max( - receipt_stream_ordering, summary_stream_ordering + receipt_stream_ordering, max_summary_stream_ordering ) - notify_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, start_unread_stream_ordering ) - counts.notify_count += notify_count - counts.unread_count += unread_count + for notif_count, unread_count, thread_id in unread_counts: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count - return counts + return RoomNotifCounts(main_counts, thread_counts) def _get_notif_unread_count_for_user_room( self, @@ -494,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, - ) -> Tuple[int, int]: + ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -510,13 +542,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas If this is not given, then no maximum is applied. Return: - A tuple of the notif count and unread count in the given range. + A tuple of the notif count and unread count in the given range for + each thread. """ # If there have been no events in the room since the stream ordering, # there can't be any push actions either. if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): - return 0, 0 + return [] clause = "" args = [user_id, room_id, stream_ordering] @@ -527,26 +560,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # If the max stream ordering is less than the min stream ordering, # then obviously there are zero push actions in that range. if max_stream_ordering <= stream_ordering: - return 0, 0 + return [] sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), - COUNT(CASE WHEN unread = 1 THEN 1 END) - FROM event_push_actions ea - WHERE user_id = ? + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions ea + WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? {clause} + GROUP BY thread_id """ txn.execute(sql, args) - row = txn.fetchone() - - if row: - return cast(Tuple[int, int], row) - - return 0, 0 + return cast(List[Tuple[int, int, str]], txn.fetchall()) async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int @@ -1099,26 +1129,34 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. - notif_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering ) - # Replace the previous summary with the new counts. - # - # TODO(threads): Upsert per-thread instead of setting them all to main. - self.db_pool.simple_upsert_txn( + # First mark the summary for all threads in the room as cleared. + self.db_pool.simple_update_txn( txn, table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id}, - values={ - "notif_count": notif_count, - "unread_count": unread_count, + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={ + "notif_count": 0, + "unread_count": 0, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, - "thread_id": "main", }, ) + # Then any updated threads get their notification count and unread + # count updated. + self.db_pool.simple_update_many_txn( + txn, + table="event_push_summary", + key_names=("room_id", "user_id", "thread_id"), + key_values=[(room_id, user_id, row[2]) for row in unread_counts], + value_names=("notif_count", "unread_count"), + value_values=[(row[0], row[1]) for row in unread_counts], + ) + # We always update `event_push_summary_last_receipt_stream_id` to # ensure that we don't rescan the same receipts for remote users. @@ -1204,23 +1242,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Calculate the new counts that should be upserted into event_push_summary sql = """ - SELECT user_id, room_id, + SELECT user_id, room_id, thread_id, coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering FROM ( - SELECT user_id, room_id, count(*) as cnt, + SELECT user_id, room_id, thread_id, count(*) as cnt, max(ea.stream_ordering) as stream_ordering FROM event_push_actions AS ea - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ? AND ( old.last_receipt_stream_ordering IS NULL OR old.last_receipt_stream_ordering < ea.stream_ordering ) AND %s = 1 - GROUP BY user_id, room_id + GROUP BY user_id, room_id, thread_id ) AS upd - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) """ # First get the count of unread messages. @@ -1234,11 +1272,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # object because we might not have the same amount of rows in each of them. To do # this, we use a dict indexed on the user ID and room ID to make it easier to # populate. - summaries: Dict[Tuple[str, str], _EventPushSummary] = {} + summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {} for row in txn: - summaries[(row[0], row[1])] = _EventPushSummary( - unread_count=row[2], - stream_ordering=row[3], + summaries[(row[0], row[1], row[2])] = _EventPushSummary( + unread_count=row[3], + stream_ordering=row[4], notif_count=0, ) @@ -1249,34 +1287,35 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) for row in txn: - if (row[0], row[1]) in summaries: - summaries[(row[0], row[1])].notif_count = row[2] + if (row[0], row[1], row[2]) in summaries: + summaries[(row[0], row[1], row[2])].notif_count = row[3] else: # Because the rules on notifying are different than the rules on marking # a message unread, we might end up with messages that notify but aren't # marked unread, so we might not have a summary for this (user, room) # tuple to complete. - summaries[(row[0], row[1])] = _EventPushSummary( + summaries[(row[0], row[1], row[2])] = _EventPushSummary( unread_count=0, - stream_ordering=row[3], - notif_count=row[2], + stream_ordering=row[4], + notif_count=row[3], ) logger.info("Rotating notifications, handling %d rows", len(summaries)) - # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", - key_names=("user_id", "room_id"), - key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), + key_names=("user_id", "room_id", "thread_id"), + key_values=[ + (user_id, room_id, thread_id) + for user_id, room_id, thread_id in summaries + ], + value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ ( summary.notif_count, summary.unread_count, summary.stream_ordering, - "main", ) for summary in summaries.values() ], @@ -1288,7 +1327,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) async def _remove_old_push_actions_that_have_rotated(self) -> None: - """Clear out old push actions that have been summarised.""" + """ + Clear out old push actions that have been summarised (and are older than + 1 day ago). + """ # We want to clear out anything that is older than a day that *has* already # been rotated. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 4a5c947699..19dbf2da7f 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -90,9 +90,9 @@ Changes in SCHEMA_VERSION = 73; SCHEMA_COMPAT_VERSION = ( - # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72 - # could break. - 72 + # The threads_id column must exist for event_push_actions, event_push_summary, + # receipts_linearized, and receipts_graph. + 73 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql new file mode 100644 index 0000000000..0ffde9bbeb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql @@ -0,0 +1,29 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Forces the background updates from 06thread_notifications.sql to run in the +-- foreground as code will now require those to be "done". + +DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id'; + +-- Overwrite any null thread_id columns. +UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL; + +-- Do not run the event_push_summary_unique_index job if it is pending; the +-- thread_id field will be made required. +DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index'; +DROP INDEX IF EXISTS event_push_summary_unique_index; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres new file mode 100644 index 0000000000..33674f8c62 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres @@ -0,0 +1,19 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The columns can now be made non-nullable. +ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite new file mode 100644 index 0000000000..5322ad77a4 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite @@ -0,0 +1,101 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- SQLite doesn't support modifying columns to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE event_push_actions_staging_new ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL, + unread SMALLINT, + thread_id TEXT NOT NULL, + inserted_ts BIGINT +); + +CREATE TABLE event_push_actions_new ( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_tag VARCHAR(32), + actions TEXT NOT NULL, + topological_ordering BIGINT, + stream_ordering BIGINT, + notif SMALLINT, + highlight SMALLINT, + unread SMALLINT, + thread_id TEXT NOT NULL, + CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) +); + +CREATE TABLE event_push_summary_new ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + unread_count BIGINT, + last_receipt_stream_ordering BIGINT, + thread_id TEXT NOT NULL +); + +-- Swap the indexes. +DROP INDEX IF EXISTS event_push_actions_staging_id; +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id); + +DROP INDEX IF EXISTS event_push_actions_room_id_user_id; +DROP INDEX IF EXISTS event_push_actions_rm_tokens; +DROP INDEX IF EXISTS event_push_actions_stream_ordering; +DROP INDEX IF EXISTS event_push_actions_u_highlight; +DROP INDEX IF EXISTS event_push_actions_highlights_index; +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering ); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id ); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering); +CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering); + +-- Copy the data. +INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts) + SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts + FROM event_push_actions_staging; + +INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id) + SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id + FROM event_push_actions; + +INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) + SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id + FROM event_push_summary; + +-- Drop the old tables. +DROP TABLE event_push_actions_staging; +DROP TABLE event_push_actions; +DROP TABLE event_push_summary; + +-- Rename the tables. +ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging; +ALTER TABLE event_push_actions_new RENAME TO event_push_actions; +ALTER TABLE event_push_summary_new RENAME TO event_push_summary; + +-- Re-run background updates from 72/02event_push_actions_index.sql and +-- 72/06thread_notifications.sql. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_summary_unique_index2', '{}') + ON CONFLICT (update_name) DO NOTHING; +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_actions_stream_highlight_index', '{}') + ON CONFLICT (update_name) DO NOTHING; diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index efd92793c0..d42e36cdf1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -22,7 +22,10 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import ( + NotifCounts, + RoomNotifCounts, +) from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition @@ -178,7 +181,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=0), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {} + ), ) self.persist( @@ -191,7 +196,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=1), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {} + ), ) self.persist( @@ -206,7 +213,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=1, unread_count=0, notify_count=2), + RoomNotifCounts( + NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {} + ), ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 473c965e19..89f986ac34 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -20,6 +20,7 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.types import JsonDict from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -133,13 +134,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) self.assertEqual( - counts, + counts.main_timeline, NotifCounts( notify_count=noitf_count, unread_count=0, highlight_count=highlight_count, ), ) + self.assertEqual(counts.threads, {}) def _create_event(highlight: bool = False) -> str: result = self.helper.send_event( @@ -186,6 +188,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(0, 0) _create_event() + _assert_counts(1, 0) _rotate() _assert_counts(1, 0) @@ -236,6 +239,168 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0) + def test_count_aggregation_threads(self) -> None: + """ + This is essentially the same test as test_count_aggregation, but adds + events to the main timeline and to a thread. + """ + + user_id, token, _, other_token, room_id = self._create_users_and_room() + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + highlight_count: int, + thread_notif_count: int, + thread_highlight_count: int, + ) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, + unread_count=0, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_highlight_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "msg", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=thread_id, + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + _create_event() + _assert_counts(2, 0, 1, 0) + _rotate() + _assert_counts(2, 0, 1, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 2, 0) + _rotate() + _assert_counts(2, 0, 2, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 0, 0) + _rotate() + _assert_counts(1, 1, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 1, 2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( -- cgit 1.5.1 From 2b6d41ebd685fb546e52acdbcb0024dfcf5a5db1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 11:36:16 -0400 Subject: Recursively fetch the thread for receipts & notifications. (#13824) Consider an event to be part of a thread if you can follow a chain of relations up to a thread root. Part of MSC3773 & MSC3771. --- changelog.d/13824.feature | 1 + synapse/push/bulk_push_rule_evaluator.py | 5 ++ synapse/rest/client/receipts.py | 22 +++++- synapse/storage/databases/main/relations.py | 36 ++++++++++ tests/storage/test_event_push_actions.py | 100 ++++++++++++++++++++++++++++ 5 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13824.feature (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/13824.feature b/changelog.d/13824.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13824.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 61d952742d..f8c4dd74f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -286,8 +286,13 @@ class BulkPushRuleEvaluator: relation.parent_id, itertools.chain(*(r.rules() for r in rules_by_user.values())), ) + # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id + else: + # Since the event has not yet been persisted we check whether + # the parent is part of a thread. + thread_id = await self.store.get_thread_id(relation.parent_id) or "main" evaluator = PushRuleEvaluator( _flatten_dict(event), diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index f3ff156abe..287dfdd69e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReceiptTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._main_store = hs.get_datastores().main self._known_receipt_types = { ReceiptTypes.READ, @@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet): thread_id = body.get("thread_id") if not thread_id or not isinstance(thread_id, str): raise SynapseError( - 400, "thread_id field must be a non-empty string" + 400, + "thread_id field must be a non-empty string", + Codes.INVALID_PARAM, + ) + + if receipt_type == ReceiptTypes.FULLY_READ: + raise SynapseError( + 400, + f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", + Codes.INVALID_PARAM, + ) + + # Ensure the event ID roughly correlates to the thread ID. + if thread_id != await self._main_store.get_thread_id(event_id): + raise SynapseError( + 400, + f"event_id {event_id} is not related to thread {thread_id}", + Codes.INVALID_PARAM, ) await self.presence_handler.bump_presence_active_time(requester.user) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 898947af95..154385b1e8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_event_relations", _get_event_relations ) + @cached() + async def get_thread_id(self, event_id: str) -> Optional[str]: + """ + Get the thread ID for an event. This considers multi-level relations, + e.g. an annotation to an event which is part of a thread. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. None, otherwise. + """ + # Since event relations form a tree, we should only ever find 0 or 1 + # results from the below query. + sql = """ + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + """ + + def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + txn.execute(sql, (event_id,)) + # TODO Should we ensure there's only a single result here? + row = txn.fetchone() + if row: + return row[0] + return None + + return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 6fa0cafb75..886585e9f2 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0, 0, 0) + def test_recursive_thread(self) -> None: + """ + Events related to events in a thread should still be considered part of + that thread. + """ + + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) + + # Update the user's push rules to care about reaction events. + self.get_success( + self.store.add_push_rule( + user_id, + "related_events", + priority_class=5, + conditions=[ + {"kind": "event_match", "key": "type", "pattern": "m.reaction"} + ], + actions=["notify"], + ) + ) + + def _create_event(type: str, content: JsonDict) -> str: + result = self.helper.send_event( + room_id, type=type, content=content, tok=other_token + ) + return result["event_id"] + + def _assert_counts(noitf_count: int, thread_notif_count: int) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, unread_count=0, highlight_count=0 + ), + ) + if thread_notif_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=0, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + # Create a root event. + thread_id = _create_event( + "m.room.message", {"msgtype": "m.text", "body": "msg"} + ) + _assert_counts(1, 0) + + # Reply, creating a thread. + reply_id = _create_event( + "m.room.message", + { + "msgtype": "m.text", + "body": "msg", + "m.relates_to": { + "rel_type": "m.thread", + "event_id": thread_id, + }, + }, + ) + _assert_counts(1, 1) + + # Create an event related to a thread event, this should still appear in + # the thread. + _create_event( + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": "m.annotation", + "event_id": reply_id, + "key": "A", + } + }, + ) + _assert_counts(1, 2) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( -- cgit 1.5.1 From e9a0419c8d28b8e153088073d6b76df6d7ed4ddf Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 6 Oct 2022 14:00:03 +0100 Subject: Fix sending events into rooms with non-integer power levels (#14073) --- changelog.d/14073.misc | 1 + mypy.ini | 3 ++ synapse/push/bulk_push_rule_evaluator.py | 9 +++- tests/push/test_bulk_push_rule_evaluator.py | 74 +++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14073.misc create mode 100644 tests/push/test_bulk_push_rule_evaluator.py (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/14073.misc b/changelog.d/14073.misc new file mode 100644 index 0000000000..7775500194 --- /dev/null +++ b/changelog.d/14073.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.68.0 where messages could not be sent in rooms with non-integer `notifications` power level. diff --git a/mypy.ini b/mypy.ini index 64f9097206..34b4523e00 100644 --- a/mypy.ini +++ b/mypy.ini @@ -106,6 +106,9 @@ disallow_untyped_defs = False [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True +[mypy-tests.push.test_bulk_push_rule_evaluator] +disallow_untyped_defs = True + [mypy-tests.test_server] disallow_untyped_defs = True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4270438918..998354648f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -289,11 +289,18 @@ class BulkPushRuleEvaluator: if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id + # It's possible that old room versions have non-integer power levels (floats or + # strings). Workaround this by explicitly converting to int. + notification_levels = power_levels.get("notifications", {}) + if not event.room_version.msc3667_int_only_power_levels: + for user_id, level in notification_levels.items(): + notification_levels[user_id] = int(level) + evaluator = PushRuleEvaluator( _flatten_dict(event), room_member_count, sender_power_level, - power_levels.get("notifications", {}), + notification_levels, relations, self._relations_match_enabled, ) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py new file mode 100644 index 0000000000..675d7df2ac --- /dev/null +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -0,0 +1,74 @@ +from unittest.mock import patch + +from synapse.api.room_versions import RoomVersions +from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.types import create_requester + +from tests import unittest + + +class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + register.register_servlets, + ] + + def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None: + """We should convert floats and strings to integers before passing to Rust. + + Reproduces #14060. + + A lack of validation: the gift that keeps on giving. + """ + # Create a new user and room. + alice = self.register_user("alice", "pass") + token = self.login(alice, "pass") + + room_id = self.helper.create_room_as( + alice, room_version=RoomVersions.V9.identifier, tok=token + ) + + # Alter the power levels in that room to include stringy and floaty levels. + # We need to suppress the validation logic or else it will reject these dodgy + # values. (Presumably this validation was not always present.) + event_creation_handler = self.hs.get_event_creation_handler() + requester = create_requester(alice) + with patch("synapse.events.validator.validate_canonicaljson"), patch( + "synapse.events.validator.jsonschema.validate" + ): + self.helper.send_state( + room_id, + "m.room.power_levels", + { + "users": {alice: "100"}, # stringy + "notifications": {"room": 100.0}, # float + }, + token, + state_key="", + ) + + # Create a new message event, and try to evaluate it under the dodgy + # power level event. + event, context = self.get_success( + event_creation_handler.create_event( + requester, + { + "type": "m.room.message", + "room_id": room_id, + "content": { + "msgtype": "m.text", + "body": "helo", + }, + "sender": alice, + }, + ) + ) + + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + # should not raise + self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) -- cgit 1.5.1 From 09be8ab5f9d54fa1a577d8b0028abf8acc28f30d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 06:26:39 -0400 Subject: Remove the experimental implementation of MSC3772. (#14094) MSC3772 has been abandoned. --- changelog.d/14094.removal | 1 + rust/src/push/base_rules.rs | 13 ---- rust/src/push/evaluator.rs | 105 +--------------------------- rust/src/push/mod.rs | 44 +++--------- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 2 - synapse/push/bulk_push_rule_evaluator.py | 64 +---------------- synapse/storage/databases/main/cache.py | 3 - synapse/storage/databases/main/events.py | 5 -- synapse/storage/databases/main/push_rule.py | 15 ++-- synapse/storage/databases/main/relations.py | 53 -------------- tests/push/test_push_rule_evaluator.py | 76 +------------------- 12 files changed, 22 insertions(+), 365 deletions(-) create mode 100644 changelog.d/14094.removal (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/14094.removal b/changelog.d/14094.removal new file mode 100644 index 0000000000..6ef03b1a0f --- /dev/null +++ b/changelog.d/14094.removal @@ -0,0 +1 @@ +Remove the experimental implementation of [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772). diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 2a09cf99ae..63240cacfc 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -257,19 +257,6 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, - PushRule { - rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"), - priority_class: 1, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { - rel_type: Cow::Borrowed("m.thread"), - event_type_pattern: None, - sender: None, - sender_type: Some(Cow::Borrowed("user_id")), - })]), - actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), - default: true, - default_enabled: true, - }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.message"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index efe88ec76e..0365dd01dc 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - borrow::Cow, - collections::{BTreeMap, BTreeSet}, -}; +use std::collections::BTreeMap; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -49,13 +46,6 @@ pub struct PushRuleEvaluator { /// The `notifications` section of the current power levels in the room. notification_power_levels: BTreeMap, - /// The relations related to the event as a mapping from relation type to - /// set of sender/event type 2-tuples. - relations: BTreeMap>, - - /// Is running "relation" conditions enabled? - relation_match_enabled: bool, - /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, @@ -70,8 +60,6 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - relations: BTreeMap>, - relation_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -83,8 +71,6 @@ impl PushRuleEvaluator { body, room_member_count, notification_power_levels, - relations, - relation_match_enabled, sender_power_level, }) } @@ -203,89 +189,11 @@ impl PushRuleEvaluator { false } } - KnownCondition::RelationMatch { - rel_type, - event_type_pattern, - sender, - sender_type, - } => { - self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? - } }; Ok(result) } - /// Evaluates a relation condition. - fn match_relations( - &self, - rel_type: &str, - sender: &Option>, - sender_type: &Option>, - user_id: Option<&str>, - event_type_pattern: &Option>, - ) -> Result { - // First check if relation matching is enabled... - if !self.relation_match_enabled { - return Ok(false); - } - - // ... and if there are any relations to match against. - let relations = if let Some(relations) = self.relations.get(rel_type) { - relations - } else { - return Ok(false); - }; - - // Extract the sender pattern from the condition - let sender_pattern = if let Some(sender) = sender { - Some(sender.as_ref()) - } else if let Some(sender_type) = sender_type { - if sender_type == "user_id" { - if let Some(user_id) = user_id { - Some(user_id) - } else { - return Ok(false); - } - } else { - warn!("Unrecognized sender_type: {sender_type}"); - return Ok(false); - } - } else { - None - }; - - let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - for (relation_sender, event_type) in relations { - if let Some(pattern) = &mut sender_compiled_pattern { - if !pattern.is_match(relation_sender)? { - continue; - } - } - - if let Some(pattern) = &mut type_compiled_pattern { - if !pattern.is_match(event_type)? { - continue; - } - } - - return Ok(true); - } - - Ok(false) - } - /// Evaluates a `event_match` condition. fn match_event_match( &self, @@ -359,15 +267,8 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = PushRuleEvaluator::py_new( - flattened_keys, - 10, - Some(0), - BTreeMap::new(), - BTreeMap::new(), - true, - ) - .unwrap(); + let evaluator = + PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 208b9c0d73..0dabfab8b8 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -275,16 +275,6 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, - #[serde(rename = "org.matrix.msc3772.relation_match")] - RelationMatch { - rel_type: Cow<'static, str>, - #[serde(skip_serializing_if = "Option::is_none", rename = "type")] - event_type_pattern: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender_type: Option>, - }, } impl IntoPy for Condition { @@ -401,21 +391,15 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3772_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new( - push_rules: PushRules, - enabled_map: BTreeMap, - msc3772_enabled: bool, - ) -> Self { + pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { Self { push_rules, enabled_map, - msc3772_enabled, } } @@ -430,25 +414,13 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules - .iter() - .filter(|rule| { - // Ignore disabled experimental push rules - if !self.msc3772_enabled - && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - { - return false; - } - - true - }) - .map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules.iter().map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 5900e61450..f2a61df660 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,9 +25,7 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__( - self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3772_enabled: bool - ): ... + def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -39,8 +37,6 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - relations: Mapping[str, Set[Tuple[str, str]]], - relation_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e00cb7096c..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3772: A push rule for mutual relations. - self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index eced182fd5..8d94aeaa32 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -117,9 +114,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -200,51 +194,6 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} - - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue - - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue - - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) - - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) - @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -276,16 +225,11 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id @@ -306,8 +250,6 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 3b8ed1f7ee..a9f25a5904 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,9 +259,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3e15827986..060fe71454 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2024,11 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) - self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), - ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8295322b0e..51416b2236 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -83,9 +80,7 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -165,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -224,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 116abef9de..6b7eec4bf2 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -776,59 +776,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) - - sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} - """ - - def _get_event_relations( - txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result - - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) - @cached() async def get_thread_id(self, event_id: str) -> Optional[str]: """ diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 8804f0e0d3..decf619466 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Union import frozendict @@ -38,12 +38,7 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator( - self, - content: JsonDict, - relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, - relations_match_enabled: bool = False, - ) -> PushRuleEvaluator: + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -63,8 +58,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), - relations or {}, - relations_match_enabled, ) def test_display_name(self) -> None: @@ -299,71 +292,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) - def test_relation_match(self) -> None: - """Test the relation_match push rule kind.""" - - # Check if the experimental feature is disabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}} - ) - - # A push rule evaluator with the experimental rule enabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}}, True - ) - - # Check just relation type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and sender. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@user:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@other:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and event type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "type": "m.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check just sender, this fails since rel_type is required. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "sender": "@user:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check sender glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@*:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check event type glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "event_type": "*.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1 From 87099b6ea5cb48b03d2007c46af80bc3f0767519 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 12:15:52 -0400 Subject: Return the main timeline for events which are not part of a thread. (#14140) Fixes a bug where threaded receipts could not be sent for the main timeline. --- changelog.d/14140.feature | 1 + synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/databases/main/relations.py | 12 +++++++----- 3 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 changelog.d/14140.feature (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/14140.feature b/changelog.d/14140.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14140.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8d94aeaa32..a75386f6a0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -236,7 +236,7 @@ class BulkPushRuleEvaluator: else: # Since the event has not yet been persisted we check whether # the parent is part of a thread. - thread_id = await self.store.get_thread_id(relation.parent_id) or "main" + thread_id = await self.store.get_thread_id(relation.parent_id) # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 6b7eec4bf2..e7fbf950e6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -28,7 +28,7 @@ from typing import ( import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import MAIN_TIMELINE, RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause @@ -777,7 +777,7 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_thread_id(self, event_id: str) -> Optional[str]: + async def get_thread_id(self, event_id: str) -> str: """ Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. @@ -787,7 +787,7 @@ class RelationsWorkerStore(SQLBaseStore): Returns: The event ID of the root event in the thread, if this event is part - of a thread. None, otherwise. + of a thread. "main", otherwise. """ # Since event relations form a tree, we should only ever find 0 or 1 # results from the below query. @@ -802,13 +802,15 @@ class RelationsWorkerStore(SQLBaseStore): ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; """ - def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] - return None + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) -- cgit 1.5.1 From b7a7ff6ee39da4981dcfdce61bf8ac4735e3d047 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 21 Oct 2022 10:46:22 -0700 Subject: Add initial power level event to batch of bulk persisted events when creating a new room. (#14228) --- changelog.d/14228.misc | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 14 ++---- synapse/handlers/room.py | 39 ++++----------- synapse/push/bulk_push_rule_evaluator.py | 74 ++++++++++++++++++++++++----- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/_base.py | 2 +- 8 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14228.misc (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/14228.misc b/changelog.d/14228.misc new file mode 100644 index 0000000000..14fe31a8bc --- /dev/null +++ b/changelog.d/14228.misc @@ -0,0 +1 @@ +Add initial power level event to batch of bulk persisted events when creating a new room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 275a37a575..4fbc79a6cb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1017,7 +1017,9 @@ class FederationHandler: context = EventContext.for_outlier(self._storage_controllers) - await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] + ) try: await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06e41b5cc0..7da6316a82 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2171,8 +2171,8 @@ class FederationEventHandler: min_depth, ) else: - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] ) try: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15b828dd74..468900a07f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1433,17 +1433,9 @@ class EventCreationHandler: a room that has been un-partial stated. """ - for event, context in events_and_context: - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + events_and_context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..cc1e5c8f97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1055,9 +1055,6 @@ class RoomCreationHandler: event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 - # the last event sent/persisted to the db - last_sent_event_id: Optional[str] = None - # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1102,26 +1099,6 @@ class RoomCreationHandler: return new_event, new_context - async def send( - event: EventBase, - context: synapse.events.snapshot.EventContext, - creator: Requester, - ) -> int: - nonlocal last_sent_event_id - - ev = await self.event_creation_handler.handle_new_client_event( - requester=creator, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) - - last_sent_event_id = ev.event_id - - # we know it was persisted, so must have a stream ordering - assert ev.internal_metadata.stream_ordering - return ev.internal_metadata.stream_ordering - try: config = self._presets_dict[preset_config] except KeyError: @@ -1135,10 +1112,14 @@ class RoomCreationHandler: ) logger.debug("Sending %s in new room", EventTypes.Member) - await send(creation_event, creation_context, creator) + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + events_and_context=[(creation_event, creation_context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + last_sent_event_id = ev.event_id - # Room create event must exist at this point - assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( creator, creator.user, @@ -1157,6 +1138,7 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) @@ -1165,7 +1147,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - await send(power_event, power_context, creator) + events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1214,9 +1196,8 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - await send(pl_event, pl_context, creator) + events_to_send.append((pl_event, pl_context)) - events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a75386f6a0..d7795a9080 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -165,8 +165,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -177,15 +190,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -194,16 +218,38 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) + + @measure_func("action_for_event_by_user") + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], + ) -> None: + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -223,7 +269,9 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) # Find the event's thread ID. relation = relation_from_event(event) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 675d7df2ac..594e7937a8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise - self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ce53f808db..121f3d8d65 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) -- cgit 1.5.1 From 2d0ba3f89aaf9545d81c4027500e543ec70b68a6 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Tue, 25 Oct 2022 13:38:01 +0000 Subject: Implementation for MSC3664: Pushrules for relations (#11804) --- changelog.d/11804.feature | 1 + rust/src/push/base_rules.rs | 17 +++ rust/src/push/evaluator.rs | 99 ++++++++++++- rust/src/push/mod.rs | 61 ++++++-- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 3 + synapse/push/bulk_push_rule_evaluator.py | 49 ++++++- synapse/rest/client/capabilities.py | 5 + synapse/storage/databases/main/push_rule.py | 15 +- tests/push/test_push_rule_evaluator.py | 215 +++++++++++++++++++++++++++- 10 files changed, 454 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11804.feature (limited to 'synapse/push/bulk_push_rule_evaluator.py') diff --git a/changelog.d/11804.feature b/changelog.d/11804.feature new file mode 100644 index 0000000000..6420393541 --- /dev/null +++ b/changelog.d/11804.feature @@ -0,0 +1 @@ +Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico. diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 63240cacfc..49802fa4eb 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -25,6 +25,7 @@ use crate::push::Action; use crate::push::Condition; use crate::push::EventMatchCondition; use crate::push::PushRule; +use crate::push::RelatedEventMatchCondition; use crate::push::SetTweak; use crate::push::TweakValue; @@ -114,6 +115,22 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch( + RelatedEventMatchCondition { + key: Some(Cow::Borrowed("sender")), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_id")), + rel_type: Cow::Borrowed("m.in_reply_to"), + include_fallbacks: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0365dd01dc..cedd42c54d 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -23,6 +23,7 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, + RelatedEventMatchCondition, }; lazy_static! { @@ -49,6 +50,13 @@ pub struct PushRuleEvaluator { /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, + + /// The related events, indexed by relation type. Flattened in the same manner as + /// `flattened_keys`. + related_events_flattened: BTreeMap>, + + /// If msc3664, push rules for related events, is enabled. + related_event_match_enabled: bool, } #[pymethods] @@ -60,6 +68,8 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, + related_events_flattened: BTreeMap>, + related_event_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -72,6 +82,8 @@ impl PushRuleEvaluator { room_member_count, notification_power_levels, sender_power_level, + related_events_flattened, + related_event_match_enabled, }) } @@ -156,6 +168,9 @@ impl PushRuleEvaluator { KnownCondition::EventMatch(event_match) => { self.match_event_match(event_match, user_id)? } + KnownCondition::RelatedEventMatch(event_match) => { + self.match_related_event_match(event_match, user_id)? + } KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -239,6 +254,79 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `related_event_match` condition. (MSC3664) + fn match_related_event_match( + &self, + event_match: &RelatedEventMatchCondition, + user_id: Option<&str>, + ) -> Result { + // First check if related event matching is enabled... + if !self.related_event_match_enabled { + return Ok(false); + } + + // get the related event, fail if there is none. + let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) { + event + } else { + return Ok(false); + }; + + // If we are not matching fallbacks, don't match if our special key indicating this is a + // fallback relation is not present. + if !event_match.include_fallbacks.unwrap_or(false) + && event.contains_key("im.vector.is_falling_back") + { + return Ok(false); + } + + // if we have no key, accept the event as matching, if it existed without matching any + // fields. + let key = if let Some(key) = &event_match.key { + key + } else { + return Ok(true); + }; + + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -267,8 +355,15 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = - PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 0dabfab8b8..d57800aa4a 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -267,6 +267,8 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + #[serde(rename = "im.nheko.msc3664.related_event_match")] + RelatedEventMatch(RelatedEventMatchCondition), ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -299,6 +301,20 @@ pub struct EventMatchCondition { pub pattern_type: Option>, } +/// The body of a [`Condition::RelatedEventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RelatedEventMatchCondition { + #[serde(skip_serializing_if = "Option::is_none")] + pub key: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern_type: Option>, + pub rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_fallbacks: Option, +} + /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] @@ -391,15 +407,21 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, + msc3664_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { + pub fn py_new( + push_rules: PushRules, + enabled_map: BTreeMap, + msc3664_enabled: bool, + ) -> Self { Self { push_rules, enabled_map, + msc3664_enabled, } } @@ -414,13 +436,25 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules.iter().map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules + .iter() + .filter(|rule| { + // Ignore disabled experimental push rules + if !self.msc3664_enabled + && rule.rule_id == "global/override/.im.nheko.msc3664.reply" + { + return false; + } + + true + }) + .map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } @@ -446,6 +480,17 @@ fn test_deserialize_condition() { let _: Condition = serde_json::from_str(json).unwrap(); } +#[test] +fn test_deserialize_unstable_msc3664_condition() { + let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RelatedEventMatch(_)) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index f2a61df660..f3b6d6c933 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,7 +25,9 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... + def __init__( + self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool + ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -37,6 +39,8 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], + related_events_flattened: Mapping[str, Mapping[str, str]], + related_event_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4009add01d..d9bdd66d55 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -98,6 +98,9 @@ class ExperimentalConfig(Config): # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) + # MSC3664: Pushrules to match on related events + self.msc3664_enabled: bool = experimental.get("msc3664_enabled", False) + # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d7795a9080..75b7e126ca 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - push_rules_invalidation_counter = Counter( "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" ) @@ -107,6 +106,8 @@ class BulkPushRuleEvaluator: self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -218,6 +219,48 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation + + Returns: + Mapping of relation type to flattened events. + """ + related_events: Dict[str, Dict[str, str]] = {} + if self._related_event_match_enabled: + related_event_id = event.content.get("m.relates_to", {}).get("event_id") + relation_type = event.content.get("m.relates_to", {}).get("rel_type") + if related_event_id is not None and relation_type is not None: + related_event = await self.store.get_event( + related_event_id, allow_none=True + ) + if related_event is not None: + related_events[relation_type] = _flatten_dict(related_event) + + reply_event_id = ( + event.content.get("m.relates_to", {}) + .get("m.in_reply_to", {}) + .get("event_id") + ) + + # convert replies to pseudo relations + if reply_event_id is not None: + related_event = await self.store.get_event( + reply_event_id, allow_none=True + ) + + if related_event is not None: + related_events["m.in_reply_to"] = _flatten_dict(related_event) + + # indicate that this is from a fallback relation. + if relation_type == "m.thread" and event.content.get( + "m.relates_to", {} + ).get("is_falling_back", False): + related_events["m.in_reply_to"][ + "im.vector.is_falling_back" + ] = "" + + return related_events + async def action_for_events_by_user( self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: @@ -286,6 +329,8 @@ class BulkPushRuleEvaluator: # the parent is part of a thread. thread_id = await self.store.get_thread_id(relation.parent_id) + related_events = await self._related_events(event) + # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. notification_levels = power_levels.get("notifications", {}) @@ -298,6 +343,8 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, + related_events, + self._related_event_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 4237071c61..e84dde31b1 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -77,6 +77,11 @@ class CapabilitiesRestServlet(RestServlet): "enabled": True, } + if self.config.experimental.msc3664_enabled: + response["capabilities"]["im.nheko.msc3664.related_event_match"] = { + "enabled": self.config.experimental.msc3664_enabled, + } + return HTTPStatus.OK, response diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 51416b2236..b6c15f29f8 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,6 +29,7 @@ from typing import ( ) from synapse.api.errors import StoreError +from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -62,7 +63,9 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], enabled_map: Dict[str, bool] + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -80,7 +83,9 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules(push_rules, enabled_map) + filtered_rules = FilteredPushRules( + push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled + ) return filtered_rules @@ -160,7 +165,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map) + return _load_rules(rows, enabled_map, self.hs.config.experimental) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -219,7 +224,9 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental + ) return results diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index decf619466..fe7c145840 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: + def _get_evaluator( + self, content: JsonDict, related_events=None + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), + {} if related_events is None else related_events, + True, ) def test_display_name(self) -> None: @@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) + def test_related_event_match(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.annotation", + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + "m.annotation": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.annotation", + "pattern": "@other_user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.replace", + }, + "@other_user:test", + "display_name", + ) + ) + + def test_related_event_match_with_fallback(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.thread", + "is_falling_back": True, + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + "im.vector.is_falling_back": "", + }, + "m.thread": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": True, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": False, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + + def test_related_event_match_no_related_event(self): + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Message without related event"} + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1