diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/client/v1/push_rule.py | 44 |
1 files changed, 29 insertions, 15 deletions
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 970a019223..cf68725ca1 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet): SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") + def __init__(self, hs): + super(PushRuleRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + @defer.inlineCallbacks def on_PUT(self, request): spec = _rule_spec_from_path(request.postpath) @@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) + user_id = requester.user.to_string() + if 'attr' in spec: - yield self.set_rule_attr(requester.user.to_string(), spec, content) + yield self.set_rule_attr(user_id, spec, content) + self.notify_user(user_id) defer.returnValue((200, {})) if spec['rule_id'].startswith('.'): @@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet): after = _namespaced_rule_id(spec, after[0]) try: - yield self.hs.get_datastore().add_push_rule( - user_id=requester.user.to_string(), + yield self.store.add_push_rule( + user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, conditions=conditions, @@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet): before=before, after=after ) + self.notify_user(user_id) except InconsistentRuleException as e: raise SynapseError(400, e.message) except RuleNotFoundException as e: @@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet): spec = _rule_spec_from_path(request.postpath) requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.hs.get_datastore().delete_push_rule( - requester.user.to_string(), namespaced_rule_id + yield self.store.delete_push_rule( + user_id, namespaced_rule_id ) + self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: if e.code == 404: @@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - user = requester.user + user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.hs.get_datastore().get_push_rules_for_user( - user.to_string() - ) + rawrules = yield self.store.get_push_rules_for_user(user_id) ruleslist = [] for rawrule in rawrules: @@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet): rules['global'] = _add_empty_priority_class_arrays(rules['global']) - enabled_map = yield self.hs.get_datastore().\ - get_push_rules_enabled_for_user(user.to_string()) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) for r in ruleslist: rulearray = None @@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet): pattern_type = c.pop("pattern_type", None) if pattern_type == "user_id": - c["pattern"] = user.to_string() + c["pattern"] = user_id elif pattern_type == "user_localpart": - c["pattern"] = user.localpart + c["pattern"] = requester.user.localpart rulearray = rules['global'][template_name] @@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_OPTIONS(self, _): return 200, {} + def notify_user(self, user_id): + stream_id = self.store.get_push_rules_stream_token() + self.notifier.on_new_event( + "push_rules_key", stream_id, users=[user_id] + ) + def set_rule_attr(self, user_id, spec, val): if spec['attr'] == 'enabled': if isinstance(val, dict) and "enabled" in val: @@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.hs.get_datastore().set_push_rule_enabled( + return self.store.set_push_rule_enabled( user_id, namespaced_rule_id, val ) elif spec['attr'] == 'actions': @@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet): if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return self.hs.get_datastore().set_push_rule_actions( + return self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: |