summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py1
-rw-r--r--synapse/storage/push_rule.py110
2 files changed, 54 insertions, 57 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ee5587c721..9e348590ba 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -308,6 +308,7 @@ class SQLBaseStore(object):
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
+        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index da23c1a114..34805e276e 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -19,7 +19,6 @@ from ._base import SQLBaseStore, Table
 from twisted.internet import defer
 
 import logging
-import copy
 import simplejson as json
 
 logger = logging.getLogger(__name__)
@@ -28,46 +27,45 @@ logger = logging.getLogger(__name__)
 class PushRuleStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_push_rules_for_user(self, user_name):
-        sql = (
-            "SELECT "+",".join(PushRuleTable.fields)+" "
-            "FROM "+PushRuleTable.table_name+" "
-            "WHERE user_name = ? "
-            "ORDER BY priority_class DESC, priority DESC"
+        rows = yield self._simple_select_list(
+            table=PushRuleTable.table_name,
+            keyvalues={
+                "user_name": user_name,
+            },
+            retcols=PushRuleTable.fields,
         )
-        rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
 
-        dicts = []
-        for r in rows:
-            d = {}
-            for i, f in enumerate(PushRuleTable.fields):
-                d[f] = r[i]
-            dicts.append(d)
+        rows.sort(
+            key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
+        )
 
-        defer.returnValue(dicts)
+        defer.returnValue(rows)
 
     @defer.inlineCallbacks
     def get_push_rules_enabled_for_user(self, user_name):
         results = yield self._simple_select_list(
-            PushRuleEnableTable.table_name,
-            {'user_name': user_name},
-            PushRuleEnableTable.fields,
+            table=PushRuleEnableTable.table_name,
+            keyvalues={
+                'user_name': user_name
+            },
+            retcols=PushRuleEnableTable.fields,
             desc="get_push_rules_enabled_for_user",
         )
-        defer.returnValue(
-            {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
-        )
+        defer.returnValue({
+            r['rule_id']: False if r['enabled'] == 0 else True for r in results
+        })
 
     @defer.inlineCallbacks
     def add_push_rule(self, before, after, **kwargs):
-        vals = copy.copy(kwargs)
+        vals = kwargs
         if 'conditions' in vals:
             vals['conditions'] = json.dumps(vals['conditions'])
         if 'actions' in vals:
             vals['actions'] = json.dumps(vals['actions'])
+
         # we could check the rest of the keys are valid column names
         # but sqlite will do that anyway so I think it's just pointless.
-        if 'id' in vals:
-            del vals['id']
+        vals.pop("id", None)
 
         if before or after:
             ret = yield self.runInteraction(
@@ -87,39 +85,39 @@ class PushRuleStore(SQLBaseStore):
             defer.returnValue(ret)
 
     def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
-        after = None
-        relative_to_rule = None
-        if 'after' in kwargs and kwargs['after']:
-            after = kwargs['after']
-            relative_to_rule = after
-        if 'before' in kwargs and kwargs['before']:
-            relative_to_rule = kwargs['before']
-
-        # get the priority of the rule we're inserting after/before
-        sql = (
-            "SELECT priority_class, priority FROM ? "
-            "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
+        after = kwargs.pop("after", None)
+        relative_to_rule = kwargs.pop("before", after)
+
+        res = self._simple_select_one_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            keyvalues={
+                "user_name": user_name,
+                "rule_id": relative_to_rule,
+            },
+            retcols=["priority_class", "priority"],
+            allow_none=True,
         )
-        txn.execute(sql, (user_name, relative_to_rule))
-        res = txn.fetchall()
+
         if not res:
             raise RuleNotFoundException(
                 "before/after rule not found: %s" % (relative_to_rule,)
             )
-        priority_class, base_rule_priority = res[0]
+
+        priority_class = res["priority_class"]
+        base_rule_priority = res["priority"]
 
         if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
             raise InconsistentRuleException(
                 "Given priority class does not match class of relative rule"
             )
 
-        new_rule = copy.copy(kwargs)
-        if 'before' in new_rule:
-            del new_rule['before']
-        if 'after' in new_rule:
-            del new_rule['after']
+        new_rule = kwargs
+        new_rule.pop("before", None)
+        new_rule.pop("after", None)
         new_rule['priority_class'] = priority_class
         new_rule['user_name'] = user_name
+        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
 
         # check if the priority before/after is free
         new_rule_priority = base_rule_priority
@@ -153,12 +151,11 @@ class PushRuleStore(SQLBaseStore):
 
             txn.execute(sql, (user_name, priority_class, new_rule_priority))
 
-        # now insert the new rule
-        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
-        sql += ",".join(new_rule.keys())+") VALUES ("
-        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
-        txn.execute(sql, new_rule.values())
+        self._simple_insert_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            values=new_rule,
+        )
 
     def _add_push_rule_highest_priority_txn(self, txn, user_name,
                                             priority_class, **kwargs):
@@ -176,18 +173,17 @@ class PushRuleStore(SQLBaseStore):
             new_prio = highest_prio + 1
 
         # and insert the new rule
-        new_rule = copy.copy(kwargs)
-        if 'id' in new_rule:
-            del new_rule['id']
+        new_rule = kwargs
+        new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
         new_rule['user_name'] = user_name
         new_rule['priority_class'] = priority_class
         new_rule['priority'] = new_prio
 
-        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
-        sql += ",".join(new_rule.keys())+") VALUES ("
-        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
-
-        txn.execute(sql, new_rule.values())
+        self._simple_insert_txn(
+            txn,
+            table=PushRuleTable.table_name,
+            values=new_rule,
+        )
 
     @defer.inlineCallbacks
     def delete_push_rule(self, user_name, rule_id):