From 8ae42ab8fa3c6b52d74c24daa7ca75a478fa4fbb Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 15:39:01 +0100 Subject: Support enabling/disabling pushers (from MSC3881) (#13799) Partial implementation of MSC3881 --- tests/push/test_email.py | 4 +- tests/push/test_http.py | 148 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 137 insertions(+), 15 deletions(-) (limited to 'tests/push') diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7a3b0d6755..fd14568f55 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.pusher = self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", @@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase): """ with self.assertRaises(SynapseError) as cm: self.get_success_or_raise( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", diff --git a/tests/push/test_http.py b/tests/push/test_http.py index d9c68cdd2d..af67d84463 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable -from synapse.push import PusherConfigException -from synapse.rest.client import login, push_rule, receipts, room +from synapse.push import PusherConfig, PusherConfigException +from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase): login.register_servlets, receipts.register_servlets, push_rule.register_servlets, + pusher.register_servlets, ] user_id = True hijack_auth = False @@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase): def test_data(data: Optional[JsonDict]) -> None: self.get_failure( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.json_body) - def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + def _make_user_with_pusher( + self, username: str, enabled: bool = True + ) -> Tuple[str, str]: + """Registers a user and creates a pusher for them. + + Args: + username: the localpart of the new user's Matrix ID. + enabled: whether to create the pusher in an enabled or disabled state. + """ user_id = self.register_user(username, "pass") access_token = self.login(username, "pass") # Register the pusher + self._set_pusher(user_id, access_token, enabled) + + return user_id, access_token + + def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None: + """Creates or updates the pusher for the given user. + + Args: + user_id: the user's Matrix ID. + access_token: the access token associated with the pusher. + enabled: whether to enable or disable the pusher. + """ user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase): pushkey="a@example.com", lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, + enabled=enabled, ) ) - return user_id, access_token - def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification @@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase): # The user sends a message back (sends a notification) self.helper.send(room, body="Hello", tok=access_token) self.assertEqual(len(self.push_attempts), 1) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_disable(self) -> None: + """Tests that disabling a pusher means it's not pushed to anymore.""" + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it generated a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Disable the pusher. + self._set_pusher(user_id, access_token, enabled=False) + + # Send another message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as disabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertFalse(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_enable(self) -> None: + """Tests that enabling a disabled pusher means it gets pushed to.""" + # Create the user with the pusher already disabled. + user_id, access_token = self._make_user_with_pusher("user", enabled=False) + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # Enable the pusher. + self._set_pusher(user_id, access_token, enabled=True) + + # Send another message and check that it did generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as enabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertTrue(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_null_enabled(self) -> None: + """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers + created before the column was introduced) is considered enabled. + """ + # We intentionally set 'enabled' to None so that it's stored as NULL in the + # database. + user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type] + + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) + + def test_update_different_device_access_token(self) -> None: + """Tests that if we create a pusher from one device, the update it from another + device, the access token associated with the pusher stays the same. + """ + # Create a user with a pusher. + user_id, access_token = self._make_user_with_pusher("user") + + # Get the token ID for the current access token, since that's what we store in + # the pushers table. + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + # Generate a new access token, and update the pusher with it. + new_token = self.login("user", "pass") + self._set_pusher(user_id, new_token, enabled=False) + + # Get the current list of pushers for the user. + ret = self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) + pushers: List[PusherConfig] = list(ret) + + # Check that we still have one pusher, and that the access token associated with + # it didn't change. + self.assertEqual(len(pushers), 1) + self.assertEqual(pushers[0].access_token, token_id) -- cgit 1.5.1 From ccca14140a019c2e0430f95d78fa075efd8d535f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 16:31:53 +0100 Subject: Track device IDs for pushers (#13831) Second half of the MSC3881 implementation --- changelog.d/13831.feature | 1 + synapse/push/__init__.py | 2 + synapse/push/pusherpool.py | 10 ++- synapse/rest/client/pusher.py | 3 + synapse/storage/databases/main/pusher.py | 73 +++++++++++++++++++++- .../schema/main/delta/73/03pusher_device_id.sql | 20 ++++++ tests/push/test_http.py | 55 ++++++++++++++-- 7 files changed, 154 insertions(+), 10 deletions(-) create mode 100644 changelog.d/13831.feature create mode 100644 synapse/storage/schema/main/delta/73/03pusher_device_id.sql (limited to 'tests/push') diff --git a/changelog.d/13831.feature b/changelog.d/13831.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13831.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index ac99d35a7e..a0c760239d 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -117,6 +117,7 @@ class PusherConfig: last_success: Optional[int] failing_since: Optional[int] enabled: bool + device_id: Optional[str] def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -130,6 +131,7 @@ class PusherConfig: "profile_tag": self.profile_tag, "pushkey": self.pushkey, "enabled": self.enabled, + "device_id": self.device_id, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2597898cf4..e2648cbc93 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -107,6 +107,7 @@ class PusherPool: data: JsonDict, profile_tag: str = "", enabled: bool = True, + device_id: Optional[str] = None, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -149,18 +150,20 @@ class PusherPool: last_success=None, failing_since=None, enabled=enabled, + device_id=device_id, ) ) # Before we actually persist the pusher, we check if the user already has one - # for this app ID and pushkey. If so, we want to keep the access token in place, - # since this could be one device modifying (e.g. enabling/disabling) another - # device's pusher. + # this app ID and pushkey. If so, we want to keep the access token and device ID + # in place, since this could be one device modifying (e.g. enabling/disabling) + # another device's pusher. existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( user_id, app_id, pushkey ) if existing_config: access_token = existing_config.access_token + device_id = existing_config.device_id await self.store.add_pusher( user_id=user_id, @@ -176,6 +179,7 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, enabled=enabled, + device_id=device_id, ) pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index c9f76125dc..975eef2144 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -57,7 +57,9 @@ class PushersRestServlet(RestServlet): for pusher in pusher_dicts: if self._msc3881_enabled: pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + pusher["org.matrix.msc3881.device_id"] = pusher["device_id"] del pusher["enabled"] + del pusher["device_id"] return 200, {"pushers": pusher_dicts} @@ -134,6 +136,7 @@ class PushersSetRestServlet(RestServlet): data=content["data"], profile_tag=content.get("profile_tag", ""), enabled=enabled, + device_id=requester.device_id, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index ee55b8c4a9..01206950a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -124,7 +124,7 @@ class PusherWorkerStore(SQLBaseStore): id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_stream_ordering, last_success, failing_since, - COALESCE(enabled, TRUE) AS enabled + COALESCE(enabled, TRUE) AS enabled, device_id FROM pushers """ @@ -477,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore): return number_deleted -class PusherStore(PusherWorkerStore): +class PusherBackgroundUpdatesStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "set_device_id_for_pushers", self._set_device_id_for_pushers + ) + + async def _set_device_id_for_pushers( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update to populate the device_id column of the pushers table.""" + last_pusher_id = progress.get("pusher_id", 0) + + def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int: + txn.execute( + """ + SELECT p.id, at.device_id + FROM pushers AS p + INNER JOIN access_tokens AS at + ON p.access_token = at.id + WHERE + p.access_token IS NOT NULL + AND at.device_id IS NOT NULL + AND p.id > ? + ORDER BY p.id + LIMIT ? + """, + (last_pusher_id, batch_size), + ) + + rows = self.db_pool.cursor_to_dict(txn) + if len(rows) == 0: + return 0 + + self.db_pool.simple_update_many_txn( + txn=txn, + table="pushers", + key_names=("id",), + key_values=[(row["id"],) for row in rows], + value_names=("device_id",), + value_values=[(row["device_id"],) for row in rows], + ) + + self.db_pool.updates._background_update_progress_txn( + txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]} + ) + + return len(rows) + + nb_processed = await self.db_pool.runInteraction( + "set_device_id_for_pushers", set_device_id_for_pushers_txn + ) + + if nb_processed < batch_size: + await self.db_pool.updates._end_background_update( + "set_device_id_for_pushers" + ) + + return nb_processed + + +class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() @@ -496,6 +563,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering: int, profile_tag: str = "", enabled: bool = True, + device_id: Optional[str] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -515,6 +583,7 @@ class PusherStore(PusherWorkerStore): "profile_tag": profile_tag, "id": stream_id, "enabled": enabled, + "device_id": device_id, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/schema/main/delta/73/03pusher_device_id.sql b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql new file mode 100644 index 0000000000..1b4ffbeebe --- /dev/null +++ b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a device_id column to track the device ID that created the pusher. It's NULLable +-- on purpose, because a) it might not be possible to track down the device that created +-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and +-- b) access tokens retrieved via the admin API don't have a device associated to them. +ALTER TABLE pushers ADD COLUMN device_id TEXT; \ No newline at end of file diff --git a/tests/push/test_http.py b/tests/push/test_http.py index af67d84463..b383b8401f 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -22,6 +22,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfig, PusherConfigException from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import JsonDict from synapse.util import Clock @@ -771,6 +772,7 @@ class HTTPPusherTests(HomeserverTestCase): lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, enabled=enabled, + device_id=user_tuple.device_id, ) ) @@ -885,19 +887,21 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(channel.json_body["pushers"]), 1) self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) - def test_update_different_device_access_token(self) -> None: + def test_update_different_device_access_token_device_id(self) -> None: """Tests that if we create a pusher from one device, the update it from another - device, the access token associated with the pusher stays the same. + device, the access token and device ID associated with the pusher stays the + same. """ # Create a user with a pusher. user_id, access_token = self._make_user_with_pusher("user") # Get the token ID for the current access token, since that's what we store in - # the pushers table. + # the pushers table. Also get the device ID from it. user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id + device_id = user_tuple.device_id # Generate a new access token, and update the pusher with it. new_token = self.login("user", "pass") @@ -909,7 +913,48 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers: List[PusherConfig] = list(ret) - # Check that we still have one pusher, and that the access token associated with - # it didn't change. + # Check that we still have one pusher, and that the access token and device ID + # associated with it didn't change. self.assertEqual(len(pushers), 1) self.assertEqual(pushers[0].access_token, token_id) + self.assertEqual(pushers[0].device_id, device_id) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_device_id(self) -> None: + """Tests that a pusher created with a given device ID shows that device ID in + GET /pushers requests. + """ + self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # We create the pusher with an HTTP request rather than with + # _make_user_with_pusher so that we can test the device ID is correctly set when + # creating a pusher via an API call. + self.make_request( + method="POST", + path="/pushers/set", + content={ + "kind": "http", + "app_id": "m.http", + "app_display_name": "HTTP Push Notifications", + "device_display_name": "pushy push", + "pushkey": "a@example.com", + "lang": "en", + "data": {"url": "http://example.com/_matrix/push/v1/notify"}, + }, + access_token=access_token, + ) + + # Look up the user info for the access token so we can compare the device ID. + lookup_result: TokenLookupResult = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + + # Get the user's devices and check it has the correct device ID. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertEqual( + channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"], + lookup_result.device_id, + ) -- cgit 1.5.1 From ebd9e2dac6495a1857617d1a76c9259a988f8bb4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Sep 2022 16:12:09 +0100 Subject: Implement push rule evaluation in Rust. (#13838) --- changelog.d/13838.misc | 1 + rust/Cargo.toml | 4 +- rust/benches/evaluator.rs | 149 ++++++++++++ rust/benches/glob.rs | 40 ++++ rust/build.rs | 2 +- rust/src/push/base_rules.rs | 1 + rust/src/push/evaluator.rs | 374 +++++++++++++++++++++++++++++++ rust/src/push/mod.rs | 28 ++- rust/src/push/utils.rs | 215 ++++++++++++++++++ stubs/synapse/synapse_rust/push.pyi | 19 +- synapse/push/bulk_push_rule_evaluator.py | 44 ++-- synapse/push/httppusher.py | 39 +++- synapse/push/push_rule_evaluator.py | 361 ----------------------------- tests/push/test_push_rule_evaluator.py | 20 +- 14 files changed, 894 insertions(+), 403 deletions(-) create mode 100644 changelog.d/13838.misc create mode 100644 rust/benches/evaluator.rs create mode 100644 rust/benches/glob.rs create mode 100644 rust/src/push/evaluator.rs create mode 100644 rust/src/push/utils.rs delete mode 100644 synapse/push/push_rule_evaluator.py (limited to 'tests/push') diff --git a/changelog.d/13838.misc b/changelog.d/13838.misc new file mode 100644 index 0000000000..28bddb7059 --- /dev/null +++ b/changelog.d/13838.misc @@ -0,0 +1 @@ +Port push rules to using Rust. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 44263bf77e..cffaa5b51b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -11,7 +11,9 @@ rust-version = "1.58.1" [lib] name = "synapse" -crate-type = ["cdylib"] +# We generate a `cdylib` for Python and a standard `lib` for running +# tests/benchmarks. +crate-type = ["lib", "cdylib"] [package.metadata.maturin] # This is where we tell maturin where to place the built library. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs new file mode 100644 index 0000000000..ed411461d1 --- /dev/null +++ b/rust/benches/evaluator.rs @@ -0,0 +1,149 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(test)] +use synapse::push::{ + evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, +}; +use test::Bencher; + +extern crate test; + +#[bench] +fn bench_match_exact(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "room_id".into(), + pattern: Some("!room:server".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_match_word(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "content.body".into(), + pattern: Some("test".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_match_word_miss(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( + EventMatchCondition { + key: "content.body".into(), + pattern: Some("foobar".into()), + pattern_type: None, + }, + )); + + let matched = eval.match_condition(&condition, None, None).unwrap(); + assert!(!matched, "Didn't match"); + + b.iter(|| eval.match_condition(&condition, None, None).unwrap()); +} + +#[bench] +fn bench_eval_message(b: &mut Bencher) { + let flattened_keys = [ + ("type".to_string(), "m.text".to_string()), + ("room_id".to_string(), "!room:server".to_string()), + ("content.body".to_string(), "test message".to_string()), + ] + .into_iter() + .collect(); + + let eval = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + Default::default(), + Default::default(), + true, + ) + .unwrap(); + + let rules = + FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false); + + b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); +} diff --git a/rust/benches/glob.rs b/rust/benches/glob.rs new file mode 100644 index 0000000000..b6697d9285 --- /dev/null +++ b/rust/benches/glob.rs @@ -0,0 +1,40 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(test)] + +use synapse::push::utils::{glob_to_regex, GlobMatchType}; +use test::Bencher; + +extern crate test; + +#[bench] +fn bench_whole(b: &mut Bencher) { + b.iter(|| glob_to_regex("test", GlobMatchType::Whole)); +} + +#[bench] +fn bench_word(b: &mut Bencher) { + b.iter(|| glob_to_regex("test", GlobMatchType::Word)); +} + +#[bench] +fn bench_whole_wildcard_run(b: &mut Bencher) { + b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole)); +} + +#[bench] +fn bench_word_wildcard_run(b: &mut Bencher) { + b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole)); +} diff --git a/rust/build.rs b/rust/build.rs index 2117975e56..ef370e6b41 100644 --- a/rust/build.rs +++ b/rust/build.rs @@ -22,7 +22,7 @@ fn main() -> Result<(), std::io::Error> { for entry in entries { if entry.is_dir() { - dirs.push(entry) + dirs.push(entry); } else { paths.push(entry.to_str().expect("valid rust paths").to_string()); } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 7c62bc4849..bb59676bde 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -262,6 +262,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ priority_class: 1, conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { rel_type: Cow::Borrowed("m.thread"), + event_type_pattern: None, sender: None, sender_type: Some(Cow::Borrowed("user_id")), })]), diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs new file mode 100644 index 0000000000..efe88ec76e --- /dev/null +++ b/rust/src/push/evaluator.rs @@ -0,0 +1,374 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + borrow::Cow, + collections::{BTreeMap, BTreeSet}, +}; + +use anyhow::{Context, Error}; +use lazy_static::lazy_static; +use log::warn; +use pyo3::prelude::*; +use regex::Regex; + +use super::{ + utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, + Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, +}; + +lazy_static! { + /// Used to parse the `is` clause in the room member count condition. + static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex"); +} + +/// Allows running a set of push rules against a particular event. +#[pyclass] +pub struct PushRuleEvaluator { + /// A mapping of "flattened" keys to string values in the event, e.g. + /// includes things like "type" and "content.msgtype". + flattened_keys: BTreeMap, + + /// The "content.body", if any. + body: String, + + /// The number of users in the room. + room_member_count: u64, + + /// The `notifications` section of the current power levels in the room. + notification_power_levels: BTreeMap, + + /// The relations related to the event as a mapping from relation type to + /// set of sender/event type 2-tuples. + relations: BTreeMap>, + + /// Is running "relation" conditions enabled? + relation_match_enabled: bool, + + /// The power level of the sender of the event, or None if event is an + /// outlier. + sender_power_level: Option, +} + +#[pymethods] +impl PushRuleEvaluator { + /// Create a new `PushRuleEvaluator`. See struct docstring for details. + #[new] + pub fn py_new( + flattened_keys: BTreeMap, + room_member_count: u64, + sender_power_level: Option, + notification_power_levels: BTreeMap, + relations: BTreeMap>, + relation_match_enabled: bool, + ) -> Result { + let body = flattened_keys + .get("content.body") + .cloned() + .unwrap_or_default(); + + Ok(PushRuleEvaluator { + flattened_keys, + body, + room_member_count, + notification_power_levels, + relations, + relation_match_enabled, + sender_power_level, + }) + } + + /// Run the evaluator with the given push rules, for the given user ID and + /// display name of the user. + /// + /// Passing in None will skip evaluating rules matching user ID and display + /// name. + /// + /// Returns the set of actions, if any, that match (filtering out any + /// `dont_notify` actions). + pub fn run( + &self, + push_rules: &FilteredPushRules, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Vec { + 'outer: for (push_rule, enabled) in push_rules.iter() { + if !enabled { + continue; + } + + for condition in push_rule.conditions.iter() { + match self.match_condition(condition, user_id, display_name) { + Ok(true) => {} + Ok(false) => continue 'outer, + Err(err) => { + warn!("Condition match failed {err}"); + continue 'outer; + } + } + } + + let actions = push_rule + .actions + .iter() + // Filter out "dont_notify" actions, as we don't store them. + .filter(|a| **a != Action::DontNotify) + .cloned() + .collect(); + + return actions; + } + + Vec::new() + } + + /// Check if the given condition matches. + fn matches( + &self, + condition: Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> bool { + match self.match_condition(&condition, user_id, display_name) { + Ok(true) => true, + Ok(false) => false, + Err(err) => { + warn!("Condition match failed {err}"); + false + } + } + } +} + +impl PushRuleEvaluator { + /// Match a given `Condition` for a push rule. + pub fn match_condition( + &self, + condition: &Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Result { + let known_condition = match condition { + Condition::Known(known) => known, + Condition::Unknown(_) => { + return Ok(false); + } + }; + + let result = match known_condition { + KnownCondition::EventMatch(event_match) => { + self.match_event_match(event_match, user_id)? + } + KnownCondition::ContainsDisplayName => { + if let Some(dn) = display_name { + if !dn.is_empty() { + get_glob_matcher(dn, GlobMatchType::Word)?.is_match(&self.body)? + } else { + // We specifically ignore empty display names, as otherwise + // they would always match. + false + } + } else { + false + } + } + KnownCondition::RoomMemberCount { is } => { + if let Some(is) = is { + self.match_member_count(is)? + } else { + false + } + } + KnownCondition::SenderNotificationPermission { key } => { + if let Some(sender_power_level) = &self.sender_power_level { + let required_level = self + .notification_power_levels + .get(key.as_ref()) + .copied() + .unwrap_or(50); + + *sender_power_level >= required_level + } else { + false + } + } + KnownCondition::RelationMatch { + rel_type, + event_type_pattern, + sender, + sender_type, + } => { + self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? + } + }; + + Ok(result) + } + + /// Evaluates a relation condition. + fn match_relations( + &self, + rel_type: &str, + sender: &Option>, + sender_type: &Option>, + user_id: Option<&str>, + event_type_pattern: &Option>, + ) -> Result { + // First check if relation matching is enabled... + if !self.relation_match_enabled { + return Ok(false); + } + + // ... and if there are any relations to match against. + let relations = if let Some(relations) = self.relations.get(rel_type) { + relations + } else { + return Ok(false); + }; + + // Extract the sender pattern from the condition + let sender_pattern = if let Some(sender) = sender { + Some(sender.as_ref()) + } else if let Some(sender_type) = sender_type { + if sender_type == "user_id" { + if let Some(user_id) = user_id { + Some(user_id) + } else { + return Ok(false); + } + } else { + warn!("Unrecognized sender_type: {sender_type}"); + return Ok(false); + } + } else { + None + }; + + let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { + Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) + } else { + None + }; + + let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { + Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) + } else { + None + }; + + for (relation_sender, event_type) in relations { + if let Some(pattern) = &mut sender_compiled_pattern { + if !pattern.is_match(relation_sender)? { + continue; + } + } + + if let Some(pattern) = &mut type_compiled_pattern { + if !pattern.is_match(event_type)? { + continue; + } + } + + return Ok(true); + } + + Ok(false) + } + + /// Evaluates a `event_match` condition. + fn match_event_match( + &self, + event_match: &EventMatchCondition, + user_id: Option<&str>, + ) -> Result { + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if event_match.key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + + /// Match the member count against an 'is' condition + /// The `is` condition can be things like '>2', '==3' or even just '4'. + fn match_member_count(&self, is: &str) -> Result { + let captures = INEQUALITY_EXPR.captures(is).context("bad 'is' clause")?; + let ineq = captures.get(1).map_or("==", |m| m.as_str()); + let rhs: u64 = captures + .get(2) + .context("missing number")? + .as_str() + .parse()?; + + let matches = match ineq { + "" | "==" => self.room_member_count == rhs, + "<" => self.room_member_count < rhs, + ">" => self.room_member_count > rhs, + ">=" => self.room_member_count >= rhs, + "<=" => self.room_member_count <= rhs, + _ => false, + }; + + Ok(matches) + } +} + +#[test] +fn push_rule_evaluator() { + let mut flattened_keys = BTreeMap::new(); + flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); + + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + assert_eq!(result.len(), 3); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index de6764e7c5..30fffc31ad 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -42,7 +42,6 @@ //! //! The set of "base rules" are the list of rules that every user has by default. A //! user can modify their copy of the push rules in one of three ways: -//! //! 1. Adding a new push rule of a certain kind //! 2. Changing the actions of a base rule //! 3. Enabling/disabling a base rule. @@ -58,12 +57,16 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use anyhow::{Context, Error}; use log::warn; use pyo3::prelude::*; -use pythonize::pythonize; +use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; use serde_json::Value; +use self::evaluator::PushRuleEvaluator; + mod base_rules; +pub mod evaluator; +pub mod utils; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -71,6 +74,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?; m.add_submodule(child_module)?; @@ -274,6 +278,8 @@ pub enum KnownCondition { #[serde(rename = "org.matrix.msc3772.relation_match")] RelationMatch { rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none", rename = "type")] + event_type_pattern: Option>, #[serde(skip_serializing_if = "Option::is_none")] sender: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -287,20 +293,26 @@ impl IntoPy for Condition { } } +impl<'source> FromPyObject<'source> for Condition { + fn extract(ob: &'source PyAny) -> PyResult { + Ok(depythonize(ob)?) + } +} + /// The body of a [`Condition::EventMatch`] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventMatchCondition { - key: Cow<'static, str>, + pub key: Cow<'static, str>, #[serde(skip_serializing_if = "Option::is_none")] - pattern: Option>, + pub pattern: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pattern_type: Option>, + pub pattern_type: Option>, } /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] -struct PushRules { +pub struct PushRules { /// Custom push rules that override a base rule. overridden_base_rules: HashMap, PushRule>, @@ -319,7 +331,7 @@ struct PushRules { #[pymethods] impl PushRules { #[new] - fn new(rules: Vec) -> PushRules { + pub fn new(rules: Vec) -> PushRules { let mut push_rules: PushRules = Default::default(); for rule in rules { @@ -396,7 +408,7 @@ pub struct FilteredPushRules { #[pymethods] impl FilteredPushRules { #[new] - fn py_new( + pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, msc3786_enabled: bool, diff --git a/rust/src/push/utils.rs b/rust/src/push/utils.rs new file mode 100644 index 0000000000..8759340473 --- /dev/null +++ b/rust/src/push/utils.rs @@ -0,0 +1,215 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::bail; +use anyhow::Context; +use anyhow::Error; +use lazy_static::lazy_static; +use regex; +use regex::Regex; +use regex::RegexBuilder; + +lazy_static! { + /// Matches runs of non-wildcard characters followed by wildcard characters. + static ref WILDCARD_RUN: Regex = Regex::new(r"([^\?\*]*)([\?\*]*)").expect("valid regex"); +} + +/// Extract the localpart from a Matrix style ID +pub(crate) fn get_localpart_from_id(id: &str) -> Result<&str, Error> { + let (localpart, _) = id + .split_once(':') + .with_context(|| format!("ID does not contain colon: {id}"))?; + + // We need to strip off the first character, which is the ID type. + if localpart.is_empty() { + bail!("Invalid ID {id}"); + } + + Ok(&localpart[1..]) +} + +/// Used by `glob_to_regex` to specify what to match the regex against. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GlobMatchType { + /// The generated regex will match against the entire input. + Whole, + /// The generated regex will match against words. + Word, +} + +/// Convert a "glob" style expression to a regex, anchoring either to the entire +/// input or to individual words. +pub fn glob_to_regex(glob: &str, match_type: GlobMatchType) -> Result { + let mut chunks = Vec::new(); + + // Patterns with wildcards must be simplified to avoid performance cliffs + // - The glob `?**?**?` is equivalent to the glob `???*` + // - The glob `???*` is equivalent to the regex `.{3,}` + for captures in WILDCARD_RUN.captures_iter(glob) { + if let Some(chunk) = captures.get(1) { + chunks.push(regex::escape(chunk.as_str())); + } + + if let Some(wildcards) = captures.get(2) { + if wildcards.as_str() == "" { + continue; + } + + let question_marks = wildcards.as_str().chars().filter(|c| *c == '?').count(); + + if wildcards.as_str().contains('*') { + chunks.push(format!(".{{{question_marks},}}")); + } else { + chunks.push(format!(".{{{question_marks}}}")); + } + } + } + + let joined = chunks.join(""); + + let regex_str = match match_type { + GlobMatchType::Whole => format!(r"\A{joined}\z"), + + // `^|\W` and `\W|$` handle the case where `pattern` starts or ends with a non-word + // character. + GlobMatchType::Word => format!(r"(?:^|\b|\W){joined}(?:\b|\W|$)"), + }; + + Ok(RegexBuilder::new(®ex_str) + .case_insensitive(true) + .build()?) +} + +/// Compiles the glob into a `Matcher`. +pub fn get_glob_matcher(glob: &str, match_type: GlobMatchType) -> Result { + // There are a number of shortcuts we can make if the glob doesn't contain a + // wild card. + let matcher = if glob.contains(['*', '?']) { + let regex = glob_to_regex(glob, match_type)?; + Matcher::Regex(regex) + } else if match_type == GlobMatchType::Whole { + // If there aren't any wildcards and we're matching the whole thing, + // then we simply can do a case-insensitive string match. + Matcher::Whole(glob.to_lowercase()) + } else { + // Otherwise, if we're matching against words then can first check + // if the haystack contains the glob at all. + Matcher::Word { + word: glob.to_lowercase(), + regex: None, + } + }; + + Ok(matcher) +} + +/// Matches against a glob +pub enum Matcher { + /// Plain regex matching. + Regex(Regex), + + /// Case-insensitive equality. + Whole(String), + + /// Word matching. `regex` is a cache of calling [`glob_to_regex`] on word. + Word { word: String, regex: Option }, +} + +impl Matcher { + /// Checks if the glob matches the given haystack. + pub fn is_match(&mut self, haystack: &str) -> Result { + // We want to to do case-insensitive matching, so we convert to + // lowercase first. + let haystack = haystack.to_lowercase(); + + match self { + Matcher::Regex(regex) => Ok(regex.is_match(&haystack)), + Matcher::Whole(whole) => Ok(whole == &haystack), + Matcher::Word { word, regex } => { + // If we're looking for a literal word, then we first check if + // the haystack contains the word as a substring. + if !haystack.contains(&*word) { + return Ok(false); + } + + // If it does contain the word as a substring, then we need to + // check if it is an actual word by testing it against the regex. + let regex = if let Some(regex) = regex { + regex + } else { + let compiled_regex = glob_to_regex(word, GlobMatchType::Word)?; + regex.insert(compiled_regex) + }; + + Ok(regex.is_match(&haystack)) + } + } + } +} + +#[test] +fn test_get_domain_from_id() { + get_localpart_from_id("").unwrap_err(); + get_localpart_from_id(":").unwrap_err(); + get_localpart_from_id(":asd").unwrap_err(); + get_localpart_from_id("::as::asad").unwrap_err(); + + assert_eq!(get_localpart_from_id("@test:foo").unwrap(), "test"); + assert_eq!(get_localpart_from_id("@:").unwrap(), ""); + assert_eq!(get_localpart_from_id("@test:foo:907").unwrap(), "test"); +} + +#[test] +fn tset_glob() -> Result<(), Error> { + assert_eq!( + glob_to_regex("simple", GlobMatchType::Whole)?.as_str(), + r"\Asimple\z" + ); + assert_eq!( + glob_to_regex("simple*", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{0,}\z" + ); + assert_eq!( + glob_to_regex("simple?", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{1}\z" + ); + assert_eq!( + glob_to_regex("simple?*?*", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{2,}\z" + ); + assert_eq!( + glob_to_regex("simple???", GlobMatchType::Whole)?.as_str(), + r"\Asimple.{3}\z" + ); + + assert_eq!( + glob_to_regex("escape.", GlobMatchType::Whole)?.as_str(), + r"\Aescape\.\z" + ); + + assert!(glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simple")); + assert!(!glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple?", GlobMatchType::Whole)?.is_match("simples")); + assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simple")); + + assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("some simple.")); + assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("simple")); + assert!(!glob_to_regex("simple", GlobMatchType::Word)?.is_match("simples")); + + assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("Some @user:foo test")); + assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("@user:foo")); + + Ok(()) +} diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 93c4e69d42..fffb8419c6 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -1,4 +1,4 @@ -from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from synapse.types import JsonDict @@ -35,3 +35,20 @@ class FilteredPushRules: def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... + +class PushRuleEvaluator: + def __init__( + self, + flattened_keys: Mapping[str, str], + room_member_count: int, + sender_power_level: Optional[int], + notification_power_levels: Mapping[str, int], + relations: Mapping[str, Set[Tuple[str, str]]], + relation_match_enabled: bool, + ): ... + def run( + self, + push_rules: FilteredPushRules, + user_id: Optional[str], + display_name: Optional[str], + ) -> Collection[dict]: ... diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 32313e3bcf..60f3129005 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -17,6 +17,7 @@ import itertools import logging from typing import ( TYPE_CHECKING, + Any, Collection, Dict, Iterable, @@ -37,13 +38,11 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule +from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state -from .push_rule_evaluator import PushRuleEvaluatorForEvent - if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,11 +289,11 @@ class BulkPushRuleEvaluator: if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id - evaluator = PushRuleEvaluatorForEvent( - event, + evaluator = PushRuleEvaluator( + _flatten_dict(event), room_member_count, sender_power_level, - power_levels, + power_levels.get("notifications", {}), relations, self._relations_match_enabled, ) @@ -338,17 +337,10 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - for rule, enabled in rules.rules(): - if not enabled: - continue - - matches = evaluator.check_conditions(rule.conditions, uid, display_name) - if matches: - actions = [x for x in rule.actions if x != "dont_notify"] - if actions and "notify" in actions: - # Push rules say we should notify the user of this event - actions_by_user[uid] = actions - break + actions = evaluator.run(rules, uid, display_name) + if "notify" in actions: + # Push rules say we should notify the user of this event + actions_by_user[uid] = actions # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist @@ -365,3 +357,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]] Rule = Dict[str, dict] RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] + + +def _flatten_dict( + d: Union[EventBase, Mapping[str, Any]], + prefix: Optional[List[str]] = None, + result: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + if prefix is None: + prefix = [] + if result is None: + result = {} + for key, value in d.items(): + if isinstance(value, str): + result[".".join(prefix + [key])] = value.lower() + elif isinstance(value, Mapping): + _flatten_dict(value, prefix=(prefix + [key]), result=result) + + return result diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e96fb45e9f..b048b03a74 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union from prometheus_client import Counter @@ -28,7 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.storage.databases.main.event_push_actions import HttpPushAction -from . import push_rule_evaluator, push_tools +from . import push_tools if TYPE_CHECKING: from synapse.server import HomeServer @@ -56,6 +56,39 @@ http_badges_failed_counter = Counter( ) +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: + """ + Converts a list of actions into a `tweaks` dict (which can then be passed to + the push gateway). + + This function ignores all actions other than `set_tweak` actions, and treats + absent `value`s as `True`, which agrees with the only spec-defined treatment + of absent `value`s (namely, for `highlight` tweaks). + + Args: + actions: list of actions + e.g. [ + {"set_tweak": "a", "value": "AAA"}, + {"set_tweak": "b", "value": "BBB"}, + {"set_tweak": "highlight"}, + "notify" + ] + + Returns: + dictionary of tweaks for those actions + e.g. {"a": "AAA", "b": "BBB", "highlight": True} + """ + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if "set_tweak" in a: + # value is allowed to be absent in which case the value assumed + # should be True. + tweaks[a["set_tweak"]] = a.get("value", True) + return tweaks + + class HttpPusher(Pusher): INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes MAX_BACKOFF_SEC = 60 * 60 @@ -281,7 +314,7 @@ class HttpPusher(Pusher): if "notify" not in push_action.actions: return True - tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions) + tweaks = tweaks_for_actions(push_action.actions) badge = await push_tools.get_badge_count( self.hs.get_datastores().main, self.user_id, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py deleted file mode 100644 index f8176c5a42..0000000000 --- a/synapse/push/push_rule_evaluator.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -from typing import ( - Any, - Dict, - List, - Mapping, - Optional, - Pattern, - Sequence, - Set, - Tuple, - Union, -) - -from matrix_common.regex import glob_to_regex, to_word_pattern - -from synapse.events import EventBase -from synapse.types import UserID -from synapse.util.caches.lrucache import LruCache - -logger = logging.getLogger(__name__) - - -GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") -IS_GLOB = re.compile(r"[\?\*\[\]]") -INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") - - -def _room_member_count(condition: Mapping[str, Any], room_member_count: int) -> bool: - return _test_ineq_condition(condition, room_member_count) - - -def _sender_notification_permission( - condition: Mapping[str, Any], - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], -) -> bool: - if sender_power_level is None: - return False - - notif_level_key = condition.get("key") - if notif_level_key is None: - return False - - notif_levels = power_levels.get("notifications", {}) - assert isinstance(notif_levels, dict) - room_notif_level = notif_levels.get(notif_level_key, 50) - - return sender_power_level >= room_notif_level - - -def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool: - if "is" not in condition: - return False - m = INEQUALITY_EXPR.match(condition["is"]) - if not m: - return False - ineq = m.group(1) - rhs = m.group(2) - if not rhs.isdigit(): - return False - rhs_int = int(rhs) - - if ineq == "" or ineq == "==": - return number == rhs_int - elif ineq == "<": - return number < rhs_int - elif ineq == ">": - return number > rhs_int - elif ineq == ">=": - return number >= rhs_int - elif ineq == "<=": - return number <= rhs_int - else: - return False - - -def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: - """ - Converts a list of actions into a `tweaks` dict (which can then be passed to - the push gateway). - - This function ignores all actions other than `set_tweak` actions, and treats - absent `value`s as `True`, which agrees with the only spec-defined treatment - of absent `value`s (namely, for `highlight` tweaks). - - Args: - actions: list of actions - e.g. [ - {"set_tweak": "a", "value": "AAA"}, - {"set_tweak": "b", "value": "BBB"}, - {"set_tweak": "highlight"}, - "notify" - ] - - Returns: - dictionary of tweaks for those actions - e.g. {"a": "AAA", "b": "BBB", "highlight": True} - """ - tweaks = {} - for a in actions: - if not isinstance(a, dict): - continue - if "set_tweak" in a: - # value is allowed to be absent in which case the value assumed - # should be True. - tweaks[a["set_tweak"]] = a.get("value", True) - return tweaks - - -class PushRuleEvaluatorForEvent: - def __init__( - self, - event: EventBase, - room_member_count: int, - sender_power_level: Optional[int], - power_levels: Dict[str, Union[int, Dict[str, int]]], - relations: Dict[str, Set[Tuple[str, str]]], - relations_match_enabled: bool, - ): - self._event = event - self._room_member_count = room_member_count - self._sender_power_level = sender_power_level - self._power_levels = power_levels - self._relations = relations - self._relations_match_enabled = relations_match_enabled - - # Maps strings of e.g. 'content.body' -> event["content"]["body"] - self._value_cache = _flatten_dict(event) - - # Maps cache keys to final values. - self._condition_cache: Dict[str, bool] = {} - - def check_conditions( - self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's conditions/user ID/display name match the event. - - Args: - conditions: The user's conditions to match. - uid: The user's MXID. - display_name: The display name. - - Returns: - True if all conditions match the event, False otherwise. - """ - for cond in conditions: - _cache_key = cond.get("_cache_key", None) - if _cache_key: - res = self._condition_cache.get(_cache_key, None) - if res is False: - return False - elif res is True: - continue - - res = self.matches(cond, uid, display_name) - if _cache_key: - self._condition_cache[_cache_key] = bool(res) - - if not res: - return False - - return True - - def matches( - self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str] - ) -> bool: - """ - Returns true if a user's condition/user ID/display name match the event. - - Args: - condition: The user's condition to match. - uid: The user's MXID. - display_name: The display name, or None if there is not one. - - Returns: - True if the condition matches the event, False otherwise. - """ - if condition["kind"] == "event_match": - return self._event_match(condition, user_id) - elif condition["kind"] == "contains_display_name": - return self._contains_display_name(display_name) - elif condition["kind"] == "room_member_count": - return _room_member_count(condition, self._room_member_count) - elif condition["kind"] == "sender_notification_permission": - return _sender_notification_permission( - condition, self._sender_power_level, self._power_levels - ) - elif ( - condition["kind"] == "org.matrix.msc3772.relation_match" - and self._relations_match_enabled - ): - return self._relation_match(condition, user_id) - else: - # XXX This looks incorrect -- we have reached an unknown condition - # kind and are unconditionally returning that it matches. Note - # that it seems possible to provide a condition to the /pushrules - # endpoint with an unknown kind, see _rule_tuple_from_request_object. - return True - - def _event_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - pattern = condition.get("pattern", None) - - if not pattern: - pattern_type = condition.get("pattern_type", None) - if pattern_type == "user_id": - pattern = user_id - elif pattern_type == "user_localpart": - pattern = UserID.from_string(user_id).localpart - - if not pattern: - logger.warning("event_match condition with no pattern") - return False - - # XXX: optimisation: cache our pattern regexps - if condition["key"] == "content.body": - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - return _glob_matches(pattern, body, word_boundary=True) - else: - haystack = self._value_cache.get(condition["key"], None) - if haystack is None: - return False - - return _glob_matches(pattern, haystack) - - def _contains_display_name(self, display_name: Optional[str]) -> bool: - """ - Check an "event_match" push rule condition. - - Args: - display_name: The display name, or None if there is not one. - - Returns: - True if the display name is found in the event body, False otherwise. - """ - if not display_name: - return False - - body = self._event.content.get("body", None) - if not body or not isinstance(body, str): - return False - - # Similar to _glob_matches, but do not treat display_name as a glob. - r = regex_cache.get((display_name, False, True), None) - if not r: - r1 = re.escape(display_name) - r1 = to_word_pattern(r1) - r = re.compile(r1, flags=re.IGNORECASE) - regex_cache[(display_name, False, True)] = r - - return bool(r.search(body)) - - def _relation_match(self, condition: Mapping, user_id: str) -> bool: - """ - Check an "relation_match" push rule condition. - - Args: - condition: The "event_match" push rule condition to match. - user_id: The user's MXID. - - Returns: - True if the condition matches the event, False otherwise. - """ - rel_type = condition.get("rel_type") - if not rel_type: - logger.warning("relation_match condition missing rel_type") - return False - - sender_pattern = condition.get("sender") - if sender_pattern is None: - sender_type = condition.get("sender_type") - if sender_type == "user_id": - sender_pattern = user_id - type_pattern = condition.get("type") - - # If any other relations matches, return True. - for sender, event_type in self._relations.get(rel_type, ()): - if sender_pattern and not _glob_matches(sender_pattern, sender): - continue - if type_pattern and not _glob_matches(type_pattern, event_type): - continue - # All values must have matched. - return True - - # No relations matched. - return False - - -# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( - 50000, "regex_push_cache" -) - - -def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: - """Tests if value matches glob. - - Args: - glob - value: String to test against glob. - word_boundary: Whether to match against word boundaries or entire - string. Defaults to False. - """ - - try: - r = regex_cache.get((glob, True, word_boundary), None) - if not r: - r = glob_to_regex(glob, word_boundary=word_boundary) - regex_cache[(glob, True, word_boundary)] = r - return bool(r.search(value)) - except re.error: - logger.warning("Failed to parse glob to regex: %r", glob) - return False - - -def _flatten_dict( - d: Union[EventBase, Mapping[str, Any]], - prefix: Optional[List[str]] = None, - result: Optional[Dict[str, str]] = None, -) -> Dict[str, str]: - if prefix is None: - prefix = [] - if result is None: - result = {} - for key, value in d.items(): - if isinstance(value, str): - result[".".join(prefix + [key])] = value.lower() - elif isinstance(value, Mapping): - _flatten_dict(value, prefix=(prefix + [key]), result=result) - - return result diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 718f489577..b8308cbc05 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -23,11 +23,12 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService from synapse.events import FrozenEvent -from synapse.push import push_rule_evaluator -from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.push.bulk_push_rule_evaluator import _flatten_dict +from synapse.push.httppusher import tweaks_for_actions from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex +from synapse.synapse_rust.push import PushRuleEvaluator from synapse.types import JsonDict from synapse.util import Clock @@ -41,7 +42,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): content: JsonDict, relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, relations_match_enabled: bool = False, - ) -> PushRuleEvaluatorForEvent: + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -56,12 +57,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count = 0 sender_power_level = 0 power_levels: Dict[str, Union[int, Dict[str, int]]] = {} - return PushRuleEvaluatorForEvent( - event, + return PushRuleEvaluator( + _flatten_dict(event), room_member_count, sender_power_level, - power_levels, - relations or set(), + power_levels.get("notifications", {}), + relations or {}, relations_match_enabled, ) @@ -293,7 +294,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ] self.assertEqual( - push_rule_evaluator.tweaks_for_actions(actions), + tweaks_for_actions(actions), {"sound": "default", "highlight": True}, ) @@ -304,9 +305,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): evaluator = self._get_evaluator( {}, {"m.annotation": {("@user:test", "m.reaction")}} ) - condition = {"kind": "relation_match"} - # Oddly, an unknown condition always matches. - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) # A push rule evaluator with the experimental rule enabled. evaluator = self._get_evaluator( -- cgit 1.5.1 From 285b9e9b6c3558718e7d4f513062e277948ac35d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Sep 2022 14:27:00 +0100 Subject: Speed up calculating push actions in large rooms (#13973) We move the expensive check of visibility to after calculating push actions, avoiding the expensive check for users who won't get pushed anyway. I think this should have a big impact on rooms with large numbers of local users that have pushed disabled. --- changelog.d/13973.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 25 ++++++---- tests/push/test_push_rule_evaluator.py | 82 +++++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 changelog.d/13973.misc (limited to 'tests/push') diff --git a/changelog.d/13973.misc b/changelog.d/13973.misc new file mode 100644 index 0000000000..58150a2b35 --- /dev/null +++ b/changelog.d/13973.misc @@ -0,0 +1 @@ +Speed up calculating push actions in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 60f3129005..7bfe380543 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -303,20 +303,10 @@ class BulkPushRuleEvaluator: event.room_id, users ) - # This is a check for the case where user joins a room without being - # allowed to see history, and then the server receives a delayed event - # from before the user joined, which they should not be pushed for - uids_with_visibility = await filter_event_for_clients_with_state( - self.store, users, event, context - ) - for uid, rules in rules_by_user.items(): if event.sender == uid: continue - if uid not in uids_with_visibility: - continue - display_name = None profile = profiles.get(uid) if profile: @@ -342,6 +332,21 @@ class BulkPushRuleEvaluator: # Push rules say we should notify the user of this event actions_by_user[uid] = actions + # This is a check for the case where user joins a room without being + # allowed to see history, and then the server receives a delayed event + # from before the user joined, which they should not be pushed for + # + # We do this *after* calculating the push actions as a) its unlikely + # that we'll filter anyone out and b) for large rooms its likely that + # most users will have push disabled and so the set of users to check is + # much smaller. + uids_with_visibility = await filter_event_for_clients_with_state( + self.store, actions_by_user.keys(), event, context + ) + + for user_id in set(actions_by_user).difference(uids_with_visibility): + actions_by_user.pop(user_id, None) + # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index b8308cbc05..8804f0e0d3 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -19,17 +19,18 @@ import frozendict from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService from synapse.events import FrozenEvent from synapse.push.bulk_push_rule_evaluator import _flatten_dict from synapse.push.httppusher import tweaks_for_actions +from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import PushRuleEvaluator -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest @@ -437,3 +438,80 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): ) self.assertEqual(len(users_with_push_actions), 0) + + +class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.main_store = homeserver.get_datastores().main + + self.user_id1 = self.register_user("user1", "password") + self.tok1 = self.login(self.user_id1, "password") + self.user_id2 = self.register_user("user2", "password") + self.tok2 = self.login(self.user_id2, "password") + + self.room_id = self.helper.create_room_as(tok=self.tok1) + + # We want to test history visibility works correctly. + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + {"history_visibility": HistoryVisibility.JOINED}, + tok=self.tok1, + ) + + def get_notif_count(self, user_id: str) -> int: + return self.get_success( + self.main_store.db_pool.simple_select_one_onecol( + table="event_push_actions", + keyvalues={"user_id": user_id}, + retcol="COALESCE(SUM(notif), 0)", + desc="get_staging_notif_count", + ) + ) + + def test_plain_message(self) -> None: + """Test that sending a normal message in a room will trigger a + notification + """ + + # Have user2 join the room and cle + self.helper.join(self.room_id, self.user_id2, tok=self.tok2) + + # They start off with no notifications, but get them when messages are + # sent. + self.assertEqual(self.get_notif_count(self.user_id2), 0) + + user1 = UserID.from_string(self.user_id1) + self.create_and_send_event(self.room_id, user1) + + self.assertEqual(self.get_notif_count(self.user_id2), 1) + + def test_delayed_message(self) -> None: + """Test that a delayed message that was from before a user joined + doesn't cause a notification for the joined user. + """ + user1 = UserID.from_string(self.user_id1) + + # Send a message before user2 joins + event_id1 = self.create_and_send_event(self.room_id, user1) + + # Have user2 join the room + self.helper.join(self.room_id, self.user_id2, tok=self.tok2) + + # They start off with no notifications + self.assertEqual(self.get_notif_count(self.user_id2), 0) + + # Send another message that references the event before the join to + # simulate a "delayed" event + self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1]) + + # user2 should not be notified about it, because they can't see it. + self.assertEqual(self.get_notif_count(self.user_id2), 0) -- cgit 1.5.1 From e9a0419c8d28b8e153088073d6b76df6d7ed4ddf Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 6 Oct 2022 14:00:03 +0100 Subject: Fix sending events into rooms with non-integer power levels (#14073) --- changelog.d/14073.misc | 1 + mypy.ini | 3 ++ synapse/push/bulk_push_rule_evaluator.py | 9 +++- tests/push/test_bulk_push_rule_evaluator.py | 74 +++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14073.misc create mode 100644 tests/push/test_bulk_push_rule_evaluator.py (limited to 'tests/push') diff --git a/changelog.d/14073.misc b/changelog.d/14073.misc new file mode 100644 index 0000000000..7775500194 --- /dev/null +++ b/changelog.d/14073.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.68.0 where messages could not be sent in rooms with non-integer `notifications` power level. diff --git a/mypy.ini b/mypy.ini index 64f9097206..34b4523e00 100644 --- a/mypy.ini +++ b/mypy.ini @@ -106,6 +106,9 @@ disallow_untyped_defs = False [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True +[mypy-tests.push.test_bulk_push_rule_evaluator] +disallow_untyped_defs = True + [mypy-tests.test_server] disallow_untyped_defs = True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4270438918..998354648f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -289,11 +289,18 @@ class BulkPushRuleEvaluator: if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id + # It's possible that old room versions have non-integer power levels (floats or + # strings). Workaround this by explicitly converting to int. + notification_levels = power_levels.get("notifications", {}) + if not event.room_version.msc3667_int_only_power_levels: + for user_id, level in notification_levels.items(): + notification_levels[user_id] = int(level) + evaluator = PushRuleEvaluator( _flatten_dict(event), room_member_count, sender_power_level, - power_levels.get("notifications", {}), + notification_levels, relations, self._relations_match_enabled, ) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py new file mode 100644 index 0000000000..675d7df2ac --- /dev/null +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -0,0 +1,74 @@ +from unittest.mock import patch + +from synapse.api.room_versions import RoomVersions +from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.types import create_requester + +from tests import unittest + + +class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + register.register_servlets, + ] + + def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None: + """We should convert floats and strings to integers before passing to Rust. + + Reproduces #14060. + + A lack of validation: the gift that keeps on giving. + """ + # Create a new user and room. + alice = self.register_user("alice", "pass") + token = self.login(alice, "pass") + + room_id = self.helper.create_room_as( + alice, room_version=RoomVersions.V9.identifier, tok=token + ) + + # Alter the power levels in that room to include stringy and floaty levels. + # We need to suppress the validation logic or else it will reject these dodgy + # values. (Presumably this validation was not always present.) + event_creation_handler = self.hs.get_event_creation_handler() + requester = create_requester(alice) + with patch("synapse.events.validator.validate_canonicaljson"), patch( + "synapse.events.validator.jsonschema.validate" + ): + self.helper.send_state( + room_id, + "m.room.power_levels", + { + "users": {alice: "100"}, # stringy + "notifications": {"room": 100.0}, # float + }, + token, + state_key="", + ) + + # Create a new message event, and try to evaluate it under the dodgy + # power level event. + event, context = self.get_success( + event_creation_handler.create_event( + requester, + { + "type": "m.room.message", + "room_id": room_id, + "content": { + "msgtype": "m.text", + "body": "helo", + }, + "sender": alice, + }, + ) + ) + + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + # should not raise + self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) -- cgit 1.5.1 From 09be8ab5f9d54fa1a577d8b0028abf8acc28f30d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 06:26:39 -0400 Subject: Remove the experimental implementation of MSC3772. (#14094) MSC3772 has been abandoned. --- changelog.d/14094.removal | 1 + rust/src/push/base_rules.rs | 13 ---- rust/src/push/evaluator.rs | 105 +--------------------------- rust/src/push/mod.rs | 44 +++--------- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 2 - synapse/push/bulk_push_rule_evaluator.py | 64 +---------------- synapse/storage/databases/main/cache.py | 3 - synapse/storage/databases/main/events.py | 5 -- synapse/storage/databases/main/push_rule.py | 15 ++-- synapse/storage/databases/main/relations.py | 53 -------------- tests/push/test_push_rule_evaluator.py | 76 +------------------- 12 files changed, 22 insertions(+), 365 deletions(-) create mode 100644 changelog.d/14094.removal (limited to 'tests/push') diff --git a/changelog.d/14094.removal b/changelog.d/14094.removal new file mode 100644 index 0000000000..6ef03b1a0f --- /dev/null +++ b/changelog.d/14094.removal @@ -0,0 +1 @@ +Remove the experimental implementation of [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772). diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 2a09cf99ae..63240cacfc 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -257,19 +257,6 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, - PushRule { - rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"), - priority_class: 1, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { - rel_type: Cow::Borrowed("m.thread"), - event_type_pattern: None, - sender: None, - sender_type: Some(Cow::Borrowed("user_id")), - })]), - actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), - default: true, - default_enabled: true, - }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.message"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index efe88ec76e..0365dd01dc 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - borrow::Cow, - collections::{BTreeMap, BTreeSet}, -}; +use std::collections::BTreeMap; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -49,13 +46,6 @@ pub struct PushRuleEvaluator { /// The `notifications` section of the current power levels in the room. notification_power_levels: BTreeMap, - /// The relations related to the event as a mapping from relation type to - /// set of sender/event type 2-tuples. - relations: BTreeMap>, - - /// Is running "relation" conditions enabled? - relation_match_enabled: bool, - /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, @@ -70,8 +60,6 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - relations: BTreeMap>, - relation_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -83,8 +71,6 @@ impl PushRuleEvaluator { body, room_member_count, notification_power_levels, - relations, - relation_match_enabled, sender_power_level, }) } @@ -203,89 +189,11 @@ impl PushRuleEvaluator { false } } - KnownCondition::RelationMatch { - rel_type, - event_type_pattern, - sender, - sender_type, - } => { - self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? - } }; Ok(result) } - /// Evaluates a relation condition. - fn match_relations( - &self, - rel_type: &str, - sender: &Option>, - sender_type: &Option>, - user_id: Option<&str>, - event_type_pattern: &Option>, - ) -> Result { - // First check if relation matching is enabled... - if !self.relation_match_enabled { - return Ok(false); - } - - // ... and if there are any relations to match against. - let relations = if let Some(relations) = self.relations.get(rel_type) { - relations - } else { - return Ok(false); - }; - - // Extract the sender pattern from the condition - let sender_pattern = if let Some(sender) = sender { - Some(sender.as_ref()) - } else if let Some(sender_type) = sender_type { - if sender_type == "user_id" { - if let Some(user_id) = user_id { - Some(user_id) - } else { - return Ok(false); - } - } else { - warn!("Unrecognized sender_type: {sender_type}"); - return Ok(false); - } - } else { - None - }; - - let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - for (relation_sender, event_type) in relations { - if let Some(pattern) = &mut sender_compiled_pattern { - if !pattern.is_match(relation_sender)? { - continue; - } - } - - if let Some(pattern) = &mut type_compiled_pattern { - if !pattern.is_match(event_type)? { - continue; - } - } - - return Ok(true); - } - - Ok(false) - } - /// Evaluates a `event_match` condition. fn match_event_match( &self, @@ -359,15 +267,8 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = PushRuleEvaluator::py_new( - flattened_keys, - 10, - Some(0), - BTreeMap::new(), - BTreeMap::new(), - true, - ) - .unwrap(); + let evaluator = + PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 208b9c0d73..0dabfab8b8 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -275,16 +275,6 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, - #[serde(rename = "org.matrix.msc3772.relation_match")] - RelationMatch { - rel_type: Cow<'static, str>, - #[serde(skip_serializing_if = "Option::is_none", rename = "type")] - event_type_pattern: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender_type: Option>, - }, } impl IntoPy for Condition { @@ -401,21 +391,15 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3772_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new( - push_rules: PushRules, - enabled_map: BTreeMap, - msc3772_enabled: bool, - ) -> Self { + pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { Self { push_rules, enabled_map, - msc3772_enabled, } } @@ -430,25 +414,13 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules - .iter() - .filter(|rule| { - // Ignore disabled experimental push rules - if !self.msc3772_enabled - && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - { - return false; - } - - true - }) - .map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules.iter().map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 5900e61450..f2a61df660 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,9 +25,7 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__( - self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3772_enabled: bool - ): ... + def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -39,8 +37,6 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - relations: Mapping[str, Set[Tuple[str, str]]], - relation_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e00cb7096c..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3772: A push rule for mutual relations. - self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index eced182fd5..8d94aeaa32 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -117,9 +114,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -200,51 +194,6 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} - - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue - - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue - - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) - - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) - @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -276,16 +225,11 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id @@ -306,8 +250,6 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 3b8ed1f7ee..a9f25a5904 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,9 +259,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3e15827986..060fe71454 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2024,11 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) - self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), - ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8295322b0e..51416b2236 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -83,9 +80,7 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -165,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -224,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 116abef9de..6b7eec4bf2 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -776,59 +776,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) - - sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} - """ - - def _get_event_relations( - txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result - - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) - @cached() async def get_thread_id(self, event_id: str) -> Optional[str]: """ diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 8804f0e0d3..decf619466 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Union import frozendict @@ -38,12 +38,7 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator( - self, - content: JsonDict, - relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, - relations_match_enabled: bool = False, - ) -> PushRuleEvaluator: + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -63,8 +58,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), - relations or {}, - relations_match_enabled, ) def test_display_name(self) -> None: @@ -299,71 +292,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) - def test_relation_match(self) -> None: - """Test the relation_match push rule kind.""" - - # Check if the experimental feature is disabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}} - ) - - # A push rule evaluator with the experimental rule enabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}}, True - ) - - # Check just relation type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and sender. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@user:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@other:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and event type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "type": "m.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check just sender, this fails since rel_type is required. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "sender": "@user:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check sender glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@*:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check event type glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "event_type": "*.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1 From b7a7ff6ee39da4981dcfdce61bf8ac4735e3d047 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 21 Oct 2022 10:46:22 -0700 Subject: Add initial power level event to batch of bulk persisted events when creating a new room. (#14228) --- changelog.d/14228.misc | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 14 ++---- synapse/handlers/room.py | 39 ++++----------- synapse/push/bulk_push_rule_evaluator.py | 74 ++++++++++++++++++++++++----- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/_base.py | 2 +- 8 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14228.misc (limited to 'tests/push') diff --git a/changelog.d/14228.misc b/changelog.d/14228.misc new file mode 100644 index 0000000000..14fe31a8bc --- /dev/null +++ b/changelog.d/14228.misc @@ -0,0 +1 @@ +Add initial power level event to batch of bulk persisted events when creating a new room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 275a37a575..4fbc79a6cb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1017,7 +1017,9 @@ class FederationHandler: context = EventContext.for_outlier(self._storage_controllers) - await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] + ) try: await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06e41b5cc0..7da6316a82 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2171,8 +2171,8 @@ class FederationEventHandler: min_depth, ) else: - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] ) try: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15b828dd74..468900a07f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1433,17 +1433,9 @@ class EventCreationHandler: a room that has been un-partial stated. """ - for event, context in events_and_context: - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + events_and_context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..cc1e5c8f97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1055,9 +1055,6 @@ class RoomCreationHandler: event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 - # the last event sent/persisted to the db - last_sent_event_id: Optional[str] = None - # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1102,26 +1099,6 @@ class RoomCreationHandler: return new_event, new_context - async def send( - event: EventBase, - context: synapse.events.snapshot.EventContext, - creator: Requester, - ) -> int: - nonlocal last_sent_event_id - - ev = await self.event_creation_handler.handle_new_client_event( - requester=creator, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) - - last_sent_event_id = ev.event_id - - # we know it was persisted, so must have a stream ordering - assert ev.internal_metadata.stream_ordering - return ev.internal_metadata.stream_ordering - try: config = self._presets_dict[preset_config] except KeyError: @@ -1135,10 +1112,14 @@ class RoomCreationHandler: ) logger.debug("Sending %s in new room", EventTypes.Member) - await send(creation_event, creation_context, creator) + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + events_and_context=[(creation_event, creation_context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + last_sent_event_id = ev.event_id - # Room create event must exist at this point - assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( creator, creator.user, @@ -1157,6 +1138,7 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) @@ -1165,7 +1147,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - await send(power_event, power_context, creator) + events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1214,9 +1196,8 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - await send(pl_event, pl_context, creator) + events_to_send.append((pl_event, pl_context)) - events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a75386f6a0..d7795a9080 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -165,8 +165,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -177,15 +190,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -194,16 +218,38 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) + + @measure_func("action_for_event_by_user") + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], + ) -> None: + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -223,7 +269,9 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) # Find the event's thread ID. relation = relation_from_event(event) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 675d7df2ac..594e7937a8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise - self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ce53f808db..121f3d8d65 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) -- cgit 1.5.1 From 2d0ba3f89aaf9545d81c4027500e543ec70b68a6 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Tue, 25 Oct 2022 13:38:01 +0000 Subject: Implementation for MSC3664: Pushrules for relations (#11804) --- changelog.d/11804.feature | 1 + rust/src/push/base_rules.rs | 17 +++ rust/src/push/evaluator.rs | 99 ++++++++++++- rust/src/push/mod.rs | 61 ++++++-- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 3 + synapse/push/bulk_push_rule_evaluator.py | 49 ++++++- synapse/rest/client/capabilities.py | 5 + synapse/storage/databases/main/push_rule.py | 15 +- tests/push/test_push_rule_evaluator.py | 215 +++++++++++++++++++++++++++- 10 files changed, 454 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11804.feature (limited to 'tests/push') diff --git a/changelog.d/11804.feature b/changelog.d/11804.feature new file mode 100644 index 0000000000..6420393541 --- /dev/null +++ b/changelog.d/11804.feature @@ -0,0 +1 @@ +Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico. diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 63240cacfc..49802fa4eb 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -25,6 +25,7 @@ use crate::push::Action; use crate::push::Condition; use crate::push::EventMatchCondition; use crate::push::PushRule; +use crate::push::RelatedEventMatchCondition; use crate::push::SetTweak; use crate::push::TweakValue; @@ -114,6 +115,22 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch( + RelatedEventMatchCondition { + key: Some(Cow::Borrowed("sender")), + pattern: None, + pattern_type: Some(Cow::Borrowed("user_id")), + rel_type: Cow::Borrowed("m.in_reply_to"), + include_fallbacks: None, + }, + ))]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0365dd01dc..cedd42c54d 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -23,6 +23,7 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, + RelatedEventMatchCondition, }; lazy_static! { @@ -49,6 +50,13 @@ pub struct PushRuleEvaluator { /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, + + /// The related events, indexed by relation type. Flattened in the same manner as + /// `flattened_keys`. + related_events_flattened: BTreeMap>, + + /// If msc3664, push rules for related events, is enabled. + related_event_match_enabled: bool, } #[pymethods] @@ -60,6 +68,8 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, + related_events_flattened: BTreeMap>, + related_event_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -72,6 +82,8 @@ impl PushRuleEvaluator { room_member_count, notification_power_levels, sender_power_level, + related_events_flattened, + related_event_match_enabled, }) } @@ -156,6 +168,9 @@ impl PushRuleEvaluator { KnownCondition::EventMatch(event_match) => { self.match_event_match(event_match, user_id)? } + KnownCondition::RelatedEventMatch(event_match) => { + self.match_related_event_match(event_match, user_id)? + } KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -239,6 +254,79 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `related_event_match` condition. (MSC3664) + fn match_related_event_match( + &self, + event_match: &RelatedEventMatchCondition, + user_id: Option<&str>, + ) -> Result { + // First check if related event matching is enabled... + if !self.related_event_match_enabled { + return Ok(false); + } + + // get the related event, fail if there is none. + let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) { + event + } else { + return Ok(false); + }; + + // If we are not matching fallbacks, don't match if our special key indicating this is a + // fallback relation is not present. + if !event_match.include_fallbacks.unwrap_or(false) + && event.contains_key("im.vector.is_falling_back") + { + return Ok(false); + } + + // if we have no key, accept the event as matching, if it existed without matching any + // fields. + let key = if let Some(key) = &event_match.key { + key + } else { + return Ok(true); + }; + + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + let haystack = if let Some(haystack) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; + + // For the content.body we match against "words", but for everything + // else we match against the entire value. + let match_type = if key == "content.body" { + GlobMatchType::Word + } else { + GlobMatchType::Whole + }; + + let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; + compiled_pattern.is_match(haystack) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -267,8 +355,15 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = - PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + Some(0), + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 0dabfab8b8..d57800aa4a 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -267,6 +267,8 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + #[serde(rename = "im.nheko.msc3664.related_event_match")] + RelatedEventMatch(RelatedEventMatchCondition), ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -299,6 +301,20 @@ pub struct EventMatchCondition { pub pattern_type: Option>, } +/// The body of a [`Condition::RelatedEventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RelatedEventMatchCondition { + #[serde(skip_serializing_if = "Option::is_none")] + pub key: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern_type: Option>, + pub rel_type: Cow<'static, str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_fallbacks: Option, +} + /// The collection of push rules for a user. #[derive(Debug, Clone, Default)] #[pyclass(frozen)] @@ -391,15 +407,21 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, + msc3664_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { + pub fn py_new( + push_rules: PushRules, + enabled_map: BTreeMap, + msc3664_enabled: bool, + ) -> Self { Self { push_rules, enabled_map, + msc3664_enabled, } } @@ -414,13 +436,25 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules.iter().map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules + .iter() + .filter(|rule| { + // Ignore disabled experimental push rules + if !self.msc3664_enabled + && rule.rule_id == "global/override/.im.nheko.msc3664.reply" + { + return false; + } + + true + }) + .map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } @@ -446,6 +480,17 @@ fn test_deserialize_condition() { let _: Condition = serde_json::from_str(json).unwrap(); } +#[test] +fn test_deserialize_unstable_msc3664_condition() { + let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::RelatedEventMatch(_)) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index f2a61df660..f3b6d6c933 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,7 +25,9 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... + def __init__( + self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool + ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -37,6 +39,8 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], + related_events_flattened: Mapping[str, Mapping[str, str]], + related_event_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4009add01d..d9bdd66d55 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -98,6 +98,9 @@ class ExperimentalConfig(Config): # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) + # MSC3664: Pushrules to match on related events + self.msc3664_enabled: bool = experimental.get("msc3664_enabled", False) + # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d7795a9080..75b7e126ca 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - push_rules_invalidation_counter = Counter( "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" ) @@ -107,6 +106,8 @@ class BulkPushRuleEvaluator: self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -218,6 +219,48 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation + + Returns: + Mapping of relation type to flattened events. + """ + related_events: Dict[str, Dict[str, str]] = {} + if self._related_event_match_enabled: + related_event_id = event.content.get("m.relates_to", {}).get("event_id") + relation_type = event.content.get("m.relates_to", {}).get("rel_type") + if related_event_id is not None and relation_type is not None: + related_event = await self.store.get_event( + related_event_id, allow_none=True + ) + if related_event is not None: + related_events[relation_type] = _flatten_dict(related_event) + + reply_event_id = ( + event.content.get("m.relates_to", {}) + .get("m.in_reply_to", {}) + .get("event_id") + ) + + # convert replies to pseudo relations + if reply_event_id is not None: + related_event = await self.store.get_event( + reply_event_id, allow_none=True + ) + + if related_event is not None: + related_events["m.in_reply_to"] = _flatten_dict(related_event) + + # indicate that this is from a fallback relation. + if relation_type == "m.thread" and event.content.get( + "m.relates_to", {} + ).get("is_falling_back", False): + related_events["m.in_reply_to"][ + "im.vector.is_falling_back" + ] = "" + + return related_events + async def action_for_events_by_user( self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: @@ -286,6 +329,8 @@ class BulkPushRuleEvaluator: # the parent is part of a thread. thread_id = await self.store.get_thread_id(relation.parent_id) + related_events = await self._related_events(event) + # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. notification_levels = power_levels.get("notifications", {}) @@ -298,6 +343,8 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, + related_events, + self._related_event_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 4237071c61..e84dde31b1 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -77,6 +77,11 @@ class CapabilitiesRestServlet(RestServlet): "enabled": True, } + if self.config.experimental.msc3664_enabled: + response["capabilities"]["im.nheko.msc3664.related_event_match"] = { + "enabled": self.config.experimental.msc3664_enabled, + } + return HTTPStatus.OK, response diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 51416b2236..b6c15f29f8 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,6 +29,7 @@ from typing import ( ) from synapse.api.errors import StoreError +from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -62,7 +63,9 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], enabled_map: Dict[str, bool] + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -80,7 +83,9 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules(push_rules, enabled_map) + filtered_rules = FilteredPushRules( + push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled + ) return filtered_rules @@ -160,7 +165,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map) + return _load_rules(rows, enabled_map, self.hs.config.experimental) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -219,7 +224,9 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental + ) return results diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index decf619466..fe7c145840 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: + def _get_evaluator( + self, content: JsonDict, related_events=None + ) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), + {} if related_events is None else related_events, + True, ) def test_display_name(self) -> None: @@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) + def test_related_event_match(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.annotation", + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + "m.annotation": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.annotation", + "pattern": "@other_user:test", + }, + "@other_user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.replace", + }, + "@other_user:test", + "display_name", + ) + ) + + def test_related_event_match_with_fallback(self): + evaluator = self._get_evaluator( + { + "m.relates_to": { + "event_id": "$parent_event_id", + "key": "😀", + "rel_type": "m.thread", + "is_falling_back": True, + "m.in_reply_to": { + "event_id": "$parent_event_id", + }, + } + }, + { + "m.in_reply_to": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + "im.vector.is_falling_back": "", + }, + "m.thread": { + "event_id": "$parent_event_id", + "type": "m.room.message", + "sender": "@other_user:test", + "room_id": "!room:test", + "content.msgtype": "m.text", + "content.body": "Original message", + }, + }, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": True, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + "include_fallbacks": False, + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + + def test_related_event_match_no_related_event(self): + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Message without related event"} + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + "pattern": "@other_user:test", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "key": "sender", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "im.nheko.msc3664.related_event_match", + "rel_type": "m.in_reply_to", + }, + "@user:test", + "display_name", + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1