summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/push_rule.py168
1 files changed, 86 insertions, 82 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index f9a48171ba..e19a81e41f 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -99,38 +99,36 @@ class PushRuleStore(SQLBaseStore):
             results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
         defer.returnValue(results)
 
-    @defer.inlineCallbacks
-    def add_push_rule(self, before, after, **kwargs):
-        vals = kwargs
-        if 'conditions' in vals:
-            vals['conditions'] = json.dumps(vals['conditions'])
-        if 'actions' in vals:
-            vals['actions'] = json.dumps(vals['actions'])
-
-        # we could check the rest of the keys are valid column names
-        # but sqlite will do that anyway so I think it's just pointless.
-        vals.pop("id", None)
+    def add_push_rule(
+        self, user_id, rule_id, priority_class, conditions, actions,
+        before=None, after=None
+    ):
+        conditions_json = json.dumps(conditions)
+        actions_json = json.dumps(actions)
 
         if before or after:
-            ret = yield self.runInteraction(
+            return self.runInteraction(
                 "_add_push_rule_relative_txn",
                 self._add_push_rule_relative_txn,
-                before=before,
-                after=after,
-                **vals
+                user_id, rule_id, priority_class,
+                conditions_json, actions_json, before, after,
             )
-            defer.returnValue(ret)
         else:
-            ret = yield self.runInteraction(
+            return self.runInteraction(
                 "_add_push_rule_highest_priority_txn",
                 self._add_push_rule_highest_priority_txn,
-                **vals
+                user_id, rule_id, priority_class,
+                conditions_json, actions_json,
             )
-            defer.returnValue(ret)
 
-    def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
-        after = kwargs.pop("after", None)
-        before = kwargs.pop("before", None)
+    def _add_push_rule_relative_txn(
+        self, txn, user_id, rule_id, priority_class,
+        conditions_json, actions_json, before, after
+    ):
+        # Lock the table since otherwise we'll have annoying races between the
+        # SELECT here and the UPSERT below.
+        self.database_engine.lock_table(txn, "push_rules")
+
         relative_to_rule = before or after
 
         res = self._simple_select_one_txn(
@@ -149,69 +147,45 @@ class PushRuleStore(SQLBaseStore):
                 "before/after rule not found: %s" % (relative_to_rule,)
             )
 
-        priority_class = res["priority_class"]
+        base_priority_class = res["priority_class"]
         base_rule_priority = res["priority"]
 
-        if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
+        if base_priority_class != priority_class:
             raise InconsistentRuleException(
                 "Given priority class does not match class of relative rule"
             )
 
-        new_rule = kwargs
-        new_rule.pop("before", None)
-        new_rule.pop("after", None)
-        new_rule['priority_class'] = priority_class
-        new_rule['user_name'] = user_id
-        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
-
-        # check if the priority before/after is free
-        new_rule_priority = base_rule_priority
-        if after:
-            new_rule_priority -= 1
+        if before:
+            # Higher priority rules are executed first, So adding a rule before
+            # a rule means giving it a higher priority than that rule.
+            new_rule_priority = base_rule_priority + 1
         else:
-            new_rule_priority += 1
-
-        new_rule['priority'] = new_rule_priority
+            # We increment the priority of the existing rules to make space for
+            # the new rule. Therefore if we want this rule to appear after
+            # an existing rule we give it the priority of the existing rule,
+            # and then increment the priority of the existing rule.
+            new_rule_priority = base_rule_priority
 
         sql = (
-            "SELECT COUNT(*) FROM push_rules"
-            " WHERE user_name = ? AND priority_class = ? AND priority = ?"
+            "UPDATE push_rules SET priority = priority + 1"
+            " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
         )
+
         txn.execute(sql, (user_id, priority_class, new_rule_priority))
-        res = txn.fetchall()
-        num_conflicting = res[0][0]
-
-        # if there are conflicting rules, bump everything
-        if num_conflicting:
-            sql = "UPDATE push_rules SET priority = priority "
-            if after:
-                sql += "-1"
-            else:
-                sql += "+1"
-            sql += " WHERE user_name = ? AND priority_class = ? AND priority "
-            if after:
-                sql += "<= ?"
-            else:
-                sql += ">= ?"
-
-            txn.execute(sql, (user_id, priority_class, new_rule_priority))
 
-        txn.call_after(
-            self.get_push_rules_for_user.invalidate, (user_id,)
+        self._upsert_push_rule_txn(
+            txn, user_id, rule_id, priority_class, new_rule_priority,
+            conditions_json, actions_json,
         )
 
-        txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, (user_id,)
-        )
+    def _add_push_rule_highest_priority_txn(
+        self, txn, user_id, rule_id, priority_class,
+        conditions_json, actions_json
+    ):
+        # Lock the table since otherwise we'll have annoying races between the
+        # SELECT here and the UPSERT below.
+        self.database_engine.lock_table(txn, "push_rules")
 
-        self._simple_insert_txn(
-            txn,
-            table="push_rules",
-            values=new_rule,
-        )
-
-    def _add_push_rule_highest_priority_txn(self, txn, user_id,
-                                            priority_class, **kwargs):
         # find the highest priority rule in that class
         sql = (
             "SELECT COUNT(*), MAX(priority) FROM push_rules"
@@ -225,12 +199,48 @@ class PushRuleStore(SQLBaseStore):
         if how_many > 0:
             new_prio = highest_prio + 1
 
-        # and insert the new rule
-        new_rule = kwargs
-        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
-        new_rule['user_name'] = user_id
-        new_rule['priority_class'] = priority_class
-        new_rule['priority'] = new_prio
+        self._upsert_push_rule_txn(
+            txn,
+            user_id, rule_id, priority_class, new_prio,
+            conditions_json, actions_json,
+        )
+
+    def _upsert_push_rule_txn(
+        self, txn, user_id, rule_id, priority_class,
+        priority, conditions_json, actions_json
+    ):
+        """Specialised version of _simple_upsert_txn that picks a push_rule_id
+        using the _push_rule_id_gen if it needs to insert the rule. It assumes
+        that the "push_rules" table is locked"""
+
+        sql = (
+            "UPDATE push_rules"
+            " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+            " WHERE user_name = ? AND rule_id = ?"
+        )
+
+        txn.execute(sql, (
+            priority_class, priority, conditions_json, actions_json,
+            user_id, rule_id,
+        ))
+
+        if txn.rowcount == 0:
+            # We didn't update a row with the given rule_id so insert one
+            push_rule_id = self._push_rule_id_gen.get_next_txn(txn)
+
+            self._simple_insert_txn(
+                txn,
+                table="push_rules",
+                values={
+                    "id": push_rule_id,
+                    "user_name": user_id,
+                    "rule_id": rule_id,
+                    "priority_class": priority_class,
+                    "priority": priority,
+                    "conditions": conditions_json,
+                    "actions": actions_json,
+                },
+            )
 
         txn.call_after(
             self.get_push_rules_for_user.invalidate, (user_id,)
@@ -239,12 +249,6 @@ class PushRuleStore(SQLBaseStore):
             self.get_push_rules_enabled_for_user.invalidate, (user_id,)
         )
 
-        self._simple_insert_txn(
-            txn,
-            table="push_rules",
-            values=new_rule,
-        )
-
     @defer.inlineCallbacks
     def delete_push_rule(self, user_id, rule_id):
         """