diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e781a3bcf4..ddf8ed5e9c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet):
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):
+ if spec["attr"] not in ("enabled", "actions"):
+ # for the sake of potential future expansion, shouldn't report
+ # 404 in the case of an unknown request so check it corresponds to
+ # a known attribute first.
+ raise UnrecognizedRequestError()
+
+ namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+ rule_id = spec["rule_id"]
+ is_default_rule = rule_id.startswith(".")
+ if is_default_rule:
+ if namespaced_rule_id not in BASE_RULE_IDS:
+ raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return await self.store.set_push_rule_enabled(
- user_id, namespaced_rule_id, val
+ user_id, namespaced_rule_id, val, is_default_rule
)
elif spec["attr"] == "actions":
actions = val.get("actions")
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 72eaaad8b6..6aba435b92 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -362,20 +362,19 @@ class PasswordRestServlet(RestServlet):
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
- self.shadow_password(params, shadow_user.to_string())
+ await self.shadow_password(params, shadow_user.to_string())
return 200, {}
def on_OPTIONS(self, _):
return 200, {}
- @defer.inlineCallbacks
- def shadow_password(self, body, user_id):
+ async def shadow_password(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -756,7 +755,7 @@ class ThreepidRestServlet(RestServlet):
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
- self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
@@ -791,7 +790,7 @@ class ThreepidRestServlet(RestServlet):
"address": validation_session["address"],
"validated_at": validation_session["validated_at"],
}
- self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
@@ -799,13 +798,12 @@ class ThreepidRestServlet(RestServlet):
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
- @defer.inlineCallbacks
- def shadow_3pid(self, body, user_id):
+ async def shadow_3pid(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -866,20 +864,19 @@ class ThreepidAddRestServlet(RestServlet):
"address": validation_session["address"],
"validated_at": validation_session["validated_at"],
}
- self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+ await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
- @defer.inlineCallbacks
- def shadow_3pid(self, body, user_id):
+ async def shadow_3pid(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -983,7 +980,7 @@ class ThreepidDeleteRestServlet(RestServlet):
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
- self.shadow_3pid_delete(body, shadow_user.to_string())
+ await self.shadow_3pid_delete(body, shadow_user.to_string())
if ret:
id_server_unbind_result = "success"
@@ -992,13 +989,12 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
- @defer.inlineCallbacks
- def shadow_3pid_delete(self, body, user_id):
+ async def shadow_3pid_delete(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
- yield self.http_client.post_json_get_json(
+ await self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
% (shadow_hs_url, as_token, user_id),
body,
@@ -1101,45 +1097,6 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
)
-def assert_valid_next_link(hs: "HomeServer", next_link: str):
- """
- Raises a SynapseError if a given next_link value is invalid
-
- next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
- option is either empty or contains a domain that matches the one in the given next_link
-
- Args:
- hs: The homeserver object
- next_link: The next_link value given by the client
-
- Raises:
- SynapseError: If the next_link is invalid
- """
- valid = True
-
- # Parse the contents of the URL
- next_link_parsed = urlparse(next_link)
-
- # Scheme must not point to the local drive
- if next_link_parsed.scheme == "file":
- valid = False
-
- # If the domain whitelist is set, the domain must be in it
- if (
- valid
- and hs.config.next_link_domain_whitelist is not None
- and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
- ):
- valid = False
-
- if not valid:
- raise SynapseError(
- 400,
- "'next_link' domain not included in whitelist, or not http(s)",
- errcode=Codes.INVALID_PARAM,
- )
-
-
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
|