summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/push/baserules.py57
-rw-r--r--synapse/rest/client/v1/push_rule.py31
-rw-r--r--synapse/storage/push_rule.py25
3 files changed, 102 insertions, 11 deletions
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 0832c77cb4..86a2998bcc 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -13,46 +13,67 @@
 # limitations under the License.
 
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
+import copy
 
 
 def list_with_base_rules(rawrules):
+    """Combine the list of rules set by the user with the default push rules
+
+    :param list rawrules: The rules the user has modified or set.
+    :returns: A new list with the rules set by the user combined with the
+        defaults.
+    """
     ruleslist = []
 
+    # Grab the base rules that the user has modified.
+    # The modified base rules have a priority_class of -1.
+    modified_base_rules = {
+        r['rule_id']: r for r in rawrules if r['priority_class'] < 0
+    }
+
+    # Remove the modified base rules from the list, They'll be added back
+    # in the default postions in the list.
+    rawrules = [r for r in rawrules if r['priority_class'] >= 0]
+
     # shove the server default rules for each kind onto the end of each
     current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
 
     ruleslist.extend(make_base_prepend_rules(
-        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+        PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
     ))
 
     for r in rawrules:
         if r['priority_class'] < current_prio_class:
             while r['priority_class'] < current_prio_class:
                 ruleslist.extend(make_base_append_rules(
-                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                    modified_base_rules,
                 ))
                 current_prio_class -= 1
                 if current_prio_class > 0:
                     ruleslist.extend(make_base_prepend_rules(
-                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                        modified_base_rules,
                     ))
 
         ruleslist.append(r)
 
     while current_prio_class > 0:
         ruleslist.extend(make_base_append_rules(
-            PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+            PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+            modified_base_rules,
         ))
         current_prio_class -= 1
         if current_prio_class > 0:
             ruleslist.extend(make_base_prepend_rules(
-                PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                modified_base_rules,
             ))
 
     return ruleslist
 
 
-def make_base_append_rules(kind):
+def make_base_append_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
@@ -62,15 +83,31 @@ def make_base_append_rules(kind):
     elif kind == 'content':
         rules = BASE_APPEND_CONTENT_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
-def make_base_prepend_rules(kind):
+def make_base_prepend_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
         rules = BASE_PREPEND_OVERRIDE_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
@@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [
 ]
 
 
+BASE_RULE_IDS = set()
+
 for r in BASE_APPEND_CONTENT_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['content']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_PREPEND_OVERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_OVRRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_UNDERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['underride']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index d26e4cde3e..970a019223 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -22,7 +22,7 @@ from .base import ClientV1RestServlet, client_path_patterns
 from synapse.storage.push_rule import (
     InconsistentRuleException, RuleNotFoundException
 )
-import synapse.push.baserules as baserules
+from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS
 from synapse.push.rulekinds import (
     PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
 )
@@ -55,6 +55,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
             yield self.set_rule_attr(requester.user.to_string(), spec, content)
             defer.returnValue((200, {}))
 
+        if spec['rule_id'].startswith('.'):
+            # Rule ids starting with '.' are reserved for server default rules.
+            raise SynapseError(400, "cannot add new rule_ids that start with '.'")
+
         try:
             (conditions, actions) = _rule_tuple_from_request_object(
                 spec['template'],
@@ -128,7 +132,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
             ruleslist.append(rule)
 
         # We're going to be mutating this a lot, so do a deep copy
-        ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
+        ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
 
         rules = {'global': {}, 'device': {}}
 
@@ -197,6 +201,18 @@ class PushRuleRestServlet(ClientV1RestServlet):
             return self.hs.get_datastore().set_push_rule_enabled(
                 user_id, namespaced_rule_id, val
             )
+        elif spec['attr'] == 'actions':
+            actions = val.get('actions')
+            _check_actions(actions)
+            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+            rule_id = spec['rule_id']
+            is_default_rule = rule_id.startswith(".")
+            if is_default_rule:
+                if namespaced_rule_id not in BASE_RULE_IDS:
+                    raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
+            return self.hs.get_datastore().set_push_rule_actions(
+                user_id, namespaced_rule_id, actions, is_default_rule
+            )
         else:
             raise UnrecognizedRequestError()
 
@@ -274,6 +290,15 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
         raise InvalidRuleException("No actions found")
     actions = req_obj['actions']
 
+    _check_actions(actions)
+
+    return conditions, actions
+
+
+def _check_actions(actions):
+    if not isinstance(actions, list):
+        raise InvalidRuleException("No actions found")
+
     for a in actions:
         if a in ['notify', 'dont_notify', 'coalesce']:
             pass
@@ -282,8 +307,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
         else:
             raise InvalidRuleException("Unrecognised action")
 
-    return conditions, actions
-
 
 def _add_empty_priority_class_arrays(d):
     for pc in PRIORITY_CLASS_MAP.keys():
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index e19a81e41f..bb5c14d912 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -294,6 +294,31 @@ class PushRuleStore(SQLBaseStore):
             self.get_push_rules_enabled_for_user.invalidate, (user_id,)
         )
 
+    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+        actions_json = json.dumps(actions)
+
+        def set_push_rule_actions_txn(txn):
+            if is_default_rule:
+                # Add a dummy rule to the rules table with the user specified
+                # actions.
+                priority_class = -1
+                priority = 1
+                self._upsert_push_rule_txn(
+                    txn, user_id, rule_id, priority_class, priority,
+                    "[]", actions_json
+                )
+            else:
+                self._simple_update_one_txn(
+                    txn,
+                    "push_rules",
+                    {'user_name': user_id, 'rule_id': rule_id},
+                    {'actions': actions_json},
+                )
+
+        return self.runInteraction(
+            "set_push_rule_actions", set_push_rule_actions_txn,
+        )
+
 
 class RuleNotFoundException(Exception):
     pass