diff options
author | Brendan Abolivier <babolivier@matrix.org> | 2022-04-27 15:55:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-27 13:55:33 +0000 |
commit | 5ef673de4f0bf991402ee29235741a91a7cc9b02 (patch) | |
tree | 2668af82d4ea46519d9c401925d53ff07cb95881 | |
parent | Use supervisord to supervise Postgres and Caddy in the Complement image. (#12... (diff) | |
download | synapse-5ef673de4f0bf991402ee29235741a91a7cc9b02.tar.xz |
Add a module API to allow modules to edit push rule actions (#12406)
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
-rw-r--r-- | changelog.d/12406.feature | 1 | ||||
-rw-r--r-- | synapse/handlers/push_rules.py | 138 | ||||
-rw-r--r-- | synapse/module_api/__init__.py | 64 | ||||
-rw-r--r-- | synapse/module_api/errors.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/push_rule.py | 112 | ||||
-rw-r--r-- | synapse/server.py | 5 | ||||
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 15 | ||||
-rw-r--r-- | tests/module_api/test_api.py | 84 |
8 files changed, 319 insertions, 104 deletions
diff --git a/changelog.d/12406.feature b/changelog.d/12406.feature new file mode 100644 index 0000000000..e345afdee7 --- /dev/null +++ b/changelog.d/12406.feature @@ -0,0 +1 @@ +Add a module API to allow modules to change actions for existing push rules of local users. diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py new file mode 100644 index 0000000000..2599160bcc --- /dev/null +++ b/synapse/handlers/push_rules.py @@ -0,0 +1,138 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, List, Optional, Union + +import attr + +from synapse.api.errors import SynapseError, UnrecognizedRequestError +from synapse.push.baserules import BASE_RULE_IDS +from synapse.storage.push_rule import RuleNotFoundException +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RuleSpec: + scope: str + template: str + rule_id: str + attr: Optional[str] + + +class PushRulesHandler: + """A class to handle changes in push rules for users.""" + + def __init__(self, hs: "HomeServer"): + self._notifier = hs.get_notifier() + self._main_store = hs.get_datastores().main + + async def set_rule_attr( + self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] + ) -> None: + """Set an attribute (enabled or actions) on an existing push rule. + + Notifies listeners (e.g. sync handler) of the change. + + Args: + user_id: the user for which to modify the push rule. + spec: the spec of the push rule to modify. + val: the value to change the attribute to. + + Raises: + RuleNotFoundException if the rule being modified doesn't exist. + SynapseError(400) if the value is malformed. + UnrecognizedRequestError if the attribute to change is unknown. + InvalidRuleException if we're trying to change the actions on a rule but + the provided actions aren't compliant with the spec. + """ + if spec.attr not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" + rule_id = spec.rule_id + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,)) + if spec.attr == "enabled": + if isinstance(val, dict) and "enabled" in val: + val = val["enabled"] + if not isinstance(val, bool): + # Legacy fallback + # This should *actually* take a dict, but many clients pass + # bools directly, so let's not break them. + raise SynapseError(400, "Value for 'enabled' must be boolean") + await self._main_store.set_push_rule_enabled( + user_id, namespaced_rule_id, val, is_default_rule + ) + elif spec.attr == "actions": + if not isinstance(val, dict): + raise SynapseError(400, "Value must be a dict") + actions = val.get("actions") + if not isinstance(actions, list): + raise SynapseError(400, "Value for 'actions' must be dict") + check_actions(actions) + rule_id = spec.rule_id + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise RuleNotFoundException( + "Unknown rule %r" % (namespaced_rule_id,) + ) + await self._main_store.set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) + else: + raise UnrecognizedRequestError() + + self.notify_user(user_id) + + def notify_user(self, user_id: str) -> None: + """Notify listeners about a push rule change. + + Args: + user_id: the user ID the change is for. + """ + stream_id = self._main_store.get_max_push_rules_stream_id() + self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) + + +def check_actions(actions: List[Union[str, JsonDict]]) -> None: + """Check if the given actions are spec compliant. + + Args: + actions: the actions to check. + + Raises: + InvalidRuleException if the rules aren't compliant with the spec. + """ + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + + for a in actions: + if a in ["notify", "dont_notify", "coalesce"]: + pass + elif isinstance(a, dict) and "set_tweak" in a: + pass + else: + raise InvalidRuleException("Unrecognised action %s" % a) + + +class InvalidRuleException(Exception): + pass diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 7e6e23db73..834fe1b62c 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -82,6 +82,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( DirectServeHtmlResource, @@ -195,6 +196,7 @@ class ModuleApi: self._clock: Clock = hs.get_clock() self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() + self._push_rules_handler = hs.get_push_rules_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -1352,6 +1354,68 @@ class ModuleApi: """ await self._store.add_user_bound_threepid(user_id, medium, address, id_server) + def check_push_rule_actions( + self, actions: List[Union[str, Dict[str, str]]] + ) -> None: + """Checks if the given push rule actions are valid according to the Matrix + specification. + + See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid + actions. + + Added in Synapse v1.58.0. + + Args: + actions: the actions to check. + + Raises: + synapse.module_api.errors.InvalidRuleException if the actions are invalid. + """ + check_actions(actions) + + async def set_push_rule_action( + self, + user_id: str, + scope: str, + kind: str, + rule_id: str, + actions: List[Union[str, Dict[str, str]]], + ) -> None: + """Changes the actions of an existing push rule for the given user. + + See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more + information about push rules and their syntax. + + Can only be called on the main process. + + Added in Synapse v1.58.0. + + Args: + user_id: the user for which to change the push rule's actions. + scope: the push rule's scope, currently only "global" is allowed. + kind: the push rule's kind. + rule_id: the push rule's identifier. + actions: the actions to run when the rule's conditions match. + + Raises: + RuntimeError if this method is called on a worker or `scope` is invalid. + synapse.module_api.errors.RuleNotFoundException if the rule being modified + can't be found. + synapse.module_api.errors.InvalidRuleException if the actions are invalid. + """ + if self.worker_app is not None: + raise RuntimeError("module tried to change push rule actions on a worker") + + if scope != "global": + raise RuntimeError( + "invalid scope %s, only 'global' is currently allowed" % scope + ) + + spec = RuleSpec(scope, kind, rule_id, "actions") + await self._push_rules_handler.set_rule_attr( + user_id, spec, {"actions": actions} + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 1db900e41f..e58e0e60fe 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -20,10 +20,14 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config._base import ConfigError +from synapse.handlers.push_rules import InvalidRuleException +from synapse.storage.push_rule import RuleNotFoundException __all__ = [ "InvalidClientCredentialsError", "RedirectException", "SynapseError", "ConfigError", + "InvalidRuleException", + "RuleNotFoundException", ] diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index a93f6fd5e0..b98640b14a 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union - -import attr +from typing import TYPE_CHECKING, List, Sequence, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -22,6 +20,7 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) +from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -29,7 +28,6 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns @@ -40,14 +38,6 @@ if TYPE_CHECKING: from synapse.server import HomeServer -@attr.s(slots=True, frozen=True, auto_attribs=True) -class RuleSpec: - scope: str - template: str - rule_id: str - attr: Optional[str] - - class PushRuleRestServlet(RestServlet): PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( @@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet): self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None + self._push_rules_handler = hs.get_push_rules_handler() async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: @@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet): user_id = requester.user.to_string() if spec.attr: - await self.set_rule_attr(user_id, spec, content) - self.notify_user(user_id) + try: + await self._push_rules_handler.set_rule_attr(user_id, spec, content) + except InvalidRuleException as e: + raise SynapseError(400, "Invalid actions: %s" % e) + except RuleNotFoundException: + raise NotFoundError("Unknown rule") + return 200, {} if spec.rule_id.startswith("."): @@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet): before = parse_string(request, "before") if before: - before = _namespaced_rule_id(spec, before) + before = f"global/{spec.template}/{before}" after = parse_string(request, "after") if after: - after = _namespaced_rule_id(spec, after) + after = f"global/{spec.template}/{after}" try: await self.store.add_push_rule( user_id=user_id, - rule_id=_namespaced_rule_id_from_spec(spec), + rule_id=f"global/{spec.template}/{spec.rule_id}", priority_class=priority_class, conditions=conditions, actions=actions, before=before, after=after, ) - self.notify_user(user_id) + self._push_rules_handler.notify_user(user_id) except InconsistentRuleException as e: raise SynapseError(400, str(e)) except RuleNotFoundException as e: @@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" try: await self.store.delete_push_rule(user_id, namespaced_rule_id) - self.notify_user(user_id) + self._push_rules_handler.notify_user(user_id) return 200, {} except StoreError as e: if e.code == 404: @@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet): else: raise UnrecognizedRequestError() - def notify_user(self, user_id: str) -> None: - stream_id = self.store.get_max_push_rules_stream_id() - self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - - async def set_rule_attr( - self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] - ) -> None: - if spec.attr not in ("enabled", "actions"): - # for the sake of potential future expansion, shouldn't report - # 404 in the case of an unknown request so check it corresponds to - # a known attribute first. - raise UnrecognizedRequestError() - - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec.rule_id - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec.attr == "enabled": - if isinstance(val, dict) and "enabled" in val: - val = val["enabled"] - if not isinstance(val, bool): - # Legacy fallback - # This should *actually* take a dict, but many clients pass - # bools directly, so let's not break them. - raise SynapseError(400, "Value for 'enabled' must be boolean") - await self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val, is_default_rule - ) - elif spec.attr == "actions": - if not isinstance(val, dict): - raise SynapseError(400, "Value must be a dict") - actions = val.get("actions") - if not isinstance(actions, list): - raise SynapseError(400, "Value for 'actions' must be dict") - _check_actions(actions) - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec.rule_id - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - await self.store.set_push_rule_actions( - user_id, namespaced_rule_id, actions, is_default_rule - ) - else: - raise UnrecognizedRequestError() - def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec @@ -291,24 +238,11 @@ def _rule_tuple_from_request_object( raise InvalidRuleException("No actions found") actions = req_obj["actions"] - _check_actions(actions) + check_actions(actions) return conditions, actions -def _check_actions(actions: List[Union[str, JsonDict]]) -> None: - if not isinstance(actions, list): - raise InvalidRuleException("No actions found") - - for a in actions: - if a in ["notify", "dont_notify", "coalesce"]: - pass - elif isinstance(a, dict) and "set_tweak" in a: - pass - else: - raise InvalidRuleException("Unrecognised action") - - def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict: if path == []: raise UnrecognizedRequestError( @@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int: return pc -def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str: - return _namespaced_rule_id(spec, spec.rule_id) - - -def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str: - return "global/%s/%s" % (spec.template, rule_id) - - -class InvalidRuleException(Exception): - pass - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/server.py b/synapse/server.py index 37c72bd83a..d49c76518a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -91,6 +91,7 @@ from synapse.handlers.presence import ( WorkerPresenceHandler, ) from synapse.handlers.profile import ProfileHandler +from synapse.handlers.push_rules import PushRulesHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler @@ -811,6 +812,10 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountHandler(self) @cache_in_self + def get_push_rules_handler(self) -> PushRulesHandler: + return PushRulesHandler(self) + + @cache_in_self def get_outbound_redis_connection(self) -> "ConnectionHandler": """ The Redis connection used for replication. diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 92539f5d41..eb85bbd392 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -16,7 +16,7 @@ import abc import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union -from synapse.api.errors import NotFoundError, StoreError +from synapse.api.errors import StoreError from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore): are always stored in the database `push_rules` table). Raises: - NotFoundError if the rule does not exist. + RuleNotFoundException if the rule does not exist. """ async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() @@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore): ) txn.execute(sql, (user_id, rule_id)) if txn.fetchone() is None: - # needed to set NOT_FOUND code. - raise NotFoundError("Push rule does not exist.") + raise RuleNotFoundException("Push rule does not exist.") self.db_pool.simple_upsert_txn( txn, @@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore): """ Sets the `actions` state of a push rule. - Will throw NotFoundError if the rule does not exist; the Code for this - is NOT_FOUND. - Args: user_id: the user ID of the user who wishes to enable/disable the rule e.g. '@tina:example.org' @@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore): is_default_rule: True if and only if this is a server-default rule. This skips the check for existence (as only user-created rules are always stored in the database `push_rules` table). + + Raises: + RuleNotFoundException if the rule does not exist. """ actions_json = json_encoder.encode(actions) @@ -744,7 +743,7 @@ class PushRuleStore(PushRulesWorkerStore): except StoreError as serr: if serr.code == 404: # this sets the NOT_FOUND error Code - raise NotFoundError("Push rule does not exist") + raise RuleNotFoundException("Push rule does not exist") else: raise diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9fd5d59c55..8bc84aaaca 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState +from synapse.handlers.push_rules import InvalidRuleException from synapse.rest import admin -from synapse.rest.client import login, presence, profile, room +from synapse.rest.client import login, notifications, presence, profile, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence @@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase): room.register_servlets, presence.register_servlets, profile.register_servlets, + notifications.register_servlets, ] def prepare(self, reactor, clock, homeserver): @@ -553,6 +555,86 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(state[("org.matrix.test", "")].state_key, "") self.assertEqual(state[("org.matrix.test", "")].content, {}) + def test_set_push_rules_action(self) -> None: + """Test that a module can change the actions of an existing push rule for a user.""" + + # Create a room with 2 users in it. Push rules must not match if the user is the + # event's sender, so we need one user to send messages and one user to receive + # notifications. + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok) + + user_id2 = self.register_user("user2", "password") + tok2 = self.login("user2", "password") + self.helper.join(room_id, user_id2, tok=tok2) + + # Register a 3rd user and join them to the room, so that we don't accidentally + # trigger 1:1 push rules. + user_id3 = self.register_user("user3", "password") + tok3 = self.login("user3", "password") + self.helper.join(room_id, user_id3, tok=tok3) + + # Send a message as the second user and check that it notifies. + res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2) + event_id = res["event_id"] + + channel = self.make_request( + "GET", + "/notifications", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + self.assertEqual( + channel.json_body["notifications"][0]["event"]["event_id"], + event_id, + channel.json_body, + ) + + # Change the .m.rule.message actions to not notify on new messages. + self.get_success( + defer.ensureDeferred( + self.module_api.set_push_rule_action( + user_id=user_id, + scope="global", + kind="underride", + rule_id=".m.rule.message", + actions=["dont_notify"], + ) + ) + ) + + # Send another message as the second user and check that the number of + # notifications didn't change. + self.helper.send(room_id=room_id, body="here's another message", tok=tok2) + + channel = self.make_request( + "GET", + "/notifications?from=", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + + def test_check_push_rules_actions(self) -> None: + """Test that modules can check whether a list of push rules actions are spec + compliant. + """ + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions(["foo"]) + + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions({"foo": "bar"}) + + self.module_api.check_push_rule_actions(["notify"]) + + self.module_api.check_push_rule_actions( + [{"set_tweak": "sound", "value": "default"}] + ) + class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): """For testing ModuleApi functionality in a multi-worker setup""" |