diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py
index 4ef6a04c51..782765223c 100644
--- a/synapse/handlers/push_rules.py
+++ b/synapse/handlers/push_rules.py
@@ -18,15 +18,15 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
-
-import attr
+from typing import TYPE_CHECKING, Any, Dict, List, Union
from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.clientformat import format_push_rules_for_user
+from synapse.replication.http.push import PushSetRuleAttrRestServlet
from synapse.storage.push_rule import RuleNotFoundException
from synapse.synapse_rust.push import get_base_rule_ids
from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types.push import RuleSpec
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -35,14 +35,6 @@ if TYPE_CHECKING:
BASE_RULE_IDS = get_base_rule_ids()
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class RuleSpec:
- scope: str
- template: str
- rule_id: str
- attr: Optional[str]
-
-
class PushRulesHandler:
"""A class to handle changes in push rules for users."""
@@ -50,9 +42,28 @@ class PushRulesHandler:
self._notifier = hs.get_notifier()
self._main_store = hs.get_datastores().main
+ self._push_attr_repl_client = None
+ if hs.config.worker.worker_app is not None:
+ self._push_attr_repl_client = PushSetRuleAttrRestServlet.make_client(hs)
+
async def set_rule_attr(
self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
) -> None:
+ if self._push_attr_repl_client:
+ await self._push_attr_repl_client(
+ user_id=user_id,
+ scope=spec.scope,
+ template=spec.template,
+ rule_id=spec.rule_id,
+ attr=spec.attr,
+ val=val,
+ )
+ else:
+ await self._set_rule_attr(user_id, spec, val)
+
+ async def _set_rule_attr(
+ self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
+ ) -> None:
"""Set an attribute (enabled or actions) on an existing push rule.
Notifies listeners (e.g. sync handler) of the change.
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index f6bfd93d3c..c871b33858 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -66,7 +66,7 @@ from synapse.handlers.auth import (
AuthHandler,
)
from synapse.handlers.device import DeviceHandler
-from synapse.handlers.push_rules import RuleSpec, check_actions
+from synapse.handlers.push_rules import check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
@@ -141,6 +141,7 @@ from synapse.types import (
UserProfile,
create_requester,
)
+from synapse.types.push import RuleSpec
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index 8e5641707a..06fb7386dd 100644
--- a/synapse/replication/http/push.py
+++ b/synapse/replication/http/push.py
@@ -20,13 +20,14 @@
#
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Tuple, Union
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
+from synapse.types.push import RuleSpec
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -77,5 +78,59 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return 200, {}
+class PushSetRuleAttrRestServlet(ReplicationEndpoint):
+ """Updates an attr of a push rule
+
+ Request format:
+
+ POST /_synapse/replication/push_set_rule_attr/:user_id/:scope/:template/:rule_id/:attr
+
+ {
+ "vale": <new_val>,
+ }
+
+ """
+
+ NAME = "push_set_rule_attr"
+ PATH_ARGS = ("user_id", "scope", "template", "rule_id", "attr")
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self._push_rules_handler = hs.get_push_rules_handler()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore[override]
+ user_id: str,
+ scope: str,
+ template: str,
+ rule_id: str,
+ attr: str,
+ val: Union[bool, JsonDict],
+ ) -> JsonDict:
+ payload = {"val": val}
+
+ return payload
+
+ async def _handle_request( # type: ignore[override]
+ self,
+ request: Request,
+ content: JsonDict,
+ user_id: str,
+ scope: str,
+ template: str,
+ rule_id: str,
+ attr: str,
+ ) -> Tuple[int, JsonDict]:
+
+ spec = RuleSpec(scope=scope, template=template, rule_id=rule_id, attr=attr)
+
+ await self._push_rules_handler.set_rule_attr(user_id, spec, val=content["val"])
+
+ return 200, {}
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRemovePusherRestServlet(hs).register(http_server)
+ PushSetRuleAttrRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 7d58611abb..189b7a64e3 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -27,7 +27,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions
+from synapse.handlers.push_rules import InvalidRuleException, check_actions
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -39,6 +39,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.types import JsonDict
+from synapse.types.push import RuleSpec
from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
diff --git a/synapse/types/push.py b/synapse/types/push.py
new file mode 100644
index 0000000000..ec2324c4b6
--- /dev/null
+++ b/synapse/types/push.py
@@ -0,0 +1,25 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+
+from typing import Optional
+
+import attr
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RuleSpec:
+ scope: str
+ template: str
+ rule_id: str
+ attr: Optional[str]
|