diff options
author | Patrick Cloke <clokep@users.noreply.github.com> | 2021-09-03 09:22:22 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-03 09:22:22 -0400 |
commit | ecbfa4fe4fa7625dec14ef8f9bd06cc4ad141de0 (patch) | |
tree | 732bb7811ccffe2c68fb59fbb8351486ff786d77 /synapse/rest/client/push_rule.py | |
parent | Fix bug with reusing 'txn' when persisting event. (#10743) (diff) | |
download | synapse-ecbfa4fe4fa7625dec14ef8f9bd06cc4ad141de0.tar.xz |
Additional type hints for client REST servlets (part 5) (#10736)
Additionally this enforce type hints on all function signatures inside of the synapse.rest.client package.
Diffstat (limited to 'synapse/rest/client/push_rule.py')
-rw-r--r-- | synapse/rest/client/push_rule.py | 112 |
1 files changed, 69 insertions, 43 deletions
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 702b351d18..fb3211bf3a 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,22 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union + +import attr + from synapse.api.errors import ( NotFoundError, StoreError, SynapseError, UnrecognizedRequestError, ) +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_value_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RuleSpec: + scope: str + template: str + rule_id: str + attr: Optional[str] class PushRuleRestServlet(RestServlet): @@ -36,7 +54,7 @@ class PushRuleRestServlet(RestServlet): "Unrecognised request: You probably wanted a trailing slash" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,7 +63,7 @@ class PushRuleRestServlet(RestServlet): self._users_new_default_push_rules = hs.config.users_new_default_push_rules - async def on_PUT(self, request, path): + async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) - if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: + if "/" in spec.rule_id or "\\" in spec.rule_id: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if "attr" in spec: + if spec.attr: await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} - if spec["rule_id"].startswith("."): + if spec.rule_id.startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec["template"], spec["rule_id"], content + spec.template, spec.rule_id, content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet): return 200, {} - async def on_DELETE(self, request, path): + async def on_DELETE( + self, request: SynapseRequest, path: str + ) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") @@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet): else: raise - async def on_GET(self, request, path): + async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = path.split("/")[1:] + path_parts = path.split("/")[1:] - if path == []: + if path_parts == []: # we're a reference impl: pedantry is our job. raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == "": + if path_parts[0] == "": return 200, rules - elif path[0] == "global": - result = _filter_ruleset_with_path(rules["global"], path[1:]) + elif path_parts[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path_parts[1:]) return 200, result else: raise UnrecognizedRequestError() - def notify_user(self, user_id): + def notify_user(self, user_id: str) -> None: stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - async def set_rule_attr(self, user_id, spec, val): - if spec["attr"] not in ("enabled", "actions"): + async def set_rule_attr( + self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] + ) -> None: + if spec.attr not in ("enabled", "actions"): # for the sake of potential future expansion, shouldn't report # 404 in the case of an unknown request so check it corresponds to # a known attribute first. raise UnrecognizedRequestError() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] + rule_id = spec.rule_id is_default_rule = rule_id.startswith(".") if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec["attr"] == "enabled": + if spec.attr == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] if not isinstance(val, bool): @@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet): # This should *actually* take a dict, but many clients pass # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") - return await self.store.set_push_rule_enabled( + await self.store.set_push_rule_enabled( user_id, namespaced_rule_id, val, is_default_rule ) - elif spec["attr"] == "actions": + elif spec.attr == "actions": + if not isinstance(val, dict): + raise SynapseError(400, "Value must be a dict") actions = val.get("actions") + if not isinstance(actions, list): + raise SynapseError(400, "Value for 'actions' must be dict") _check_actions(actions) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] + rule_id = spec.rule_id is_default_rule = rule_id.startswith(".") if is_default_rule: if user_id in self._users_new_default_push_rules: @@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet): if namespaced_rule_id not in rule_ids: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return await self.store.set_push_rule_actions( + await self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: raise UnrecognizedRequestError() -def _rule_spec_from_path(path): +def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: - path (sequence[unicode]): the URL path components. + path: the URL path components. Returns: - dict: rule spec dict, containing scope/template/rule_id entries, - and possibly attr. + rule spec, containing scope/template/rule_id entries, and possibly attr. Raises: UnrecognizedRequestError if the path components cannot be parsed. @@ -237,17 +262,18 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = {"scope": scope, "template": template, "rule_id": rule_id} - path = path[1:] + attr = None if len(path) > 0 and len(path[0]) > 0: - spec["attr"] = path[0] + attr = path[0] - return spec + return RuleSpec(scope, template, rule_id, attr) -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): +def _rule_tuple_from_request_object( + rule_template: str, rule_id: str, req_obj: JsonDict +) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]: if rule_template in ["override", "underride"]: if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): return conditions, actions -def _check_actions(actions): +def _check_actions(actions: List[Union[str, JsonDict]]) -> None: if not isinstance(actions, list): raise InvalidRuleException("No actions found") @@ -290,7 +316,7 @@ def _check_actions(actions): raise InvalidRuleException("Unrecognised action") -def _filter_ruleset_with_path(ruleset, path): +def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict: if path == []: raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR @@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path): if r["rule_id"] == rule_id: the_rule = r if the_rule is None: - raise NotFoundError + raise NotFoundError() path = path[1:] if len(path) == 0: @@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError() -def _priority_class_from_spec(spec): - if spec["template"] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec["template"])) - pc = PRIORITY_CLASS_MAP[spec["template"]] +def _priority_class_from_spec(spec: RuleSpec) -> int: + if spec.template not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec.template)) + pc = PRIORITY_CLASS_MAP[spec.template] return pc -def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec["rule_id"]) +def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str: + return _namespaced_rule_id(spec, spec.rule_id) -def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec["template"], rule_id) +def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str: + return "global/%s/%s" % (spec.template, rule_id) class InvalidRuleException(Exception): pass -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushRuleRestServlet(hs).register(http_server) |