summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/push_rule.py44
1 files changed, 29 insertions, 15 deletions
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 970a019223..cf68725ca1 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
     SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
         "Unrecognised request: You probably wanted a trailing slash")
 
+    def __init__(self, hs):
+        super(PushRuleRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+
     @defer.inlineCallbacks
     def on_PUT(self, request):
         spec = _rule_spec_from_path(request.postpath)
@@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
 
         content = _parse_json(request)
 
+        user_id = requester.user.to_string()
+
         if 'attr' in spec:
-            yield self.set_rule_attr(requester.user.to_string(), spec, content)
+            yield self.set_rule_attr(user_id, spec, content)
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
 
         if spec['rule_id'].startswith('.'):
@@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
             after = _namespaced_rule_id(spec, after[0])
 
         try:
-            yield self.hs.get_datastore().add_push_rule(
-                user_id=requester.user.to_string(),
+            yield self.store.add_push_rule(
+                user_id=user_id,
                 rule_id=_namespaced_rule_id_from_spec(spec),
                 priority_class=priority_class,
                 conditions=conditions,
@@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 before=before,
                 after=after
             )
+            self.notify_user(user_id)
         except InconsistentRuleException as e:
             raise SynapseError(400, e.message)
         except RuleNotFoundException as e:
@@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
         spec = _rule_spec_from_path(request.postpath)
 
         requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
 
         namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
 
         try:
-            yield self.hs.get_datastore().delete_push_rule(
-                requester.user.to_string(), namespaced_rule_id
+            yield self.store.delete_push_rule(
+                user_id, namespaced_rule_id
             )
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
         except StoreError as e:
             if e.code == 404:
@@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
 
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
-            user.to_string()
-        )
+        rawrules = yield self.store.get_push_rules_for_user(user_id)
 
         ruleslist = []
         for rawrule in rawrules:
@@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
 
         rules['global'] = _add_empty_priority_class_arrays(rules['global'])
 
-        enabled_map = yield self.hs.get_datastore().\
-            get_push_rules_enabled_for_user(user.to_string())
+        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
 
         for r in ruleslist:
             rulearray = None
@@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
 
                 pattern_type = c.pop("pattern_type", None)
                 if pattern_type == "user_id":
-                    c["pattern"] = user.to_string()
+                    c["pattern"] = user_id
                 elif pattern_type == "user_localpart":
-                    c["pattern"] = user.localpart
+                    c["pattern"] = requester.user.localpart
 
             rulearray = rules['global'][template_name]
 
@@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
     def on_OPTIONS(self, _):
         return 200, {}
 
+    def notify_user(self, user_id):
+        stream_id = self.store.get_push_rules_stream_token()
+        self.notifier.on_new_event(
+            "push_rules_key", stream_id, users=[user_id]
+        )
+
     def set_rule_attr(self, user_id, spec, val):
         if spec['attr'] == 'enabled':
             if isinstance(val, dict) and "enabled" in val:
@@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
             namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            return self.hs.get_datastore().set_push_rule_enabled(
+            return self.store.set_push_rule_enabled(
                 user_id, namespaced_rule_id, val
             )
         elif spec['attr'] == 'actions':
@@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
             if is_default_rule:
                 if namespaced_rule_id not in BASE_RULE_IDS:
                     raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            return self.hs.get_datastore().set_push_rule_actions(
+            return self.store.set_push_rule_actions(
                 user_id, namespaced_rule_id, actions, is_default_rule
             )
         else: