diff options
Diffstat (limited to 'synapse/storage/push_rule.py')
-rw-r--r-- | synapse/storage/push_rule.py | 110 |
1 files changed, 53 insertions, 57 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index da23c1a114..34805e276e 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -19,7 +19,6 @@ from ._base import SQLBaseStore, Table from twisted.internet import defer import logging -import copy import simplejson as json logger = logging.getLogger(__name__) @@ -28,46 +27,45 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): @defer.inlineCallbacks def get_push_rules_for_user(self, user_name): - sql = ( - "SELECT "+",".join(PushRuleTable.fields)+" " - "FROM "+PushRuleTable.table_name+" " - "WHERE user_name = ? " - "ORDER BY priority_class DESC, priority DESC" + rows = yield self._simple_select_list( + table=PushRuleTable.table_name, + keyvalues={ + "user_name": user_name, + }, + retcols=PushRuleTable.fields, ) - rows = yield self._execute("get_push_rules_for_user", None, sql, user_name) - dicts = [] - for r in rows: - d = {} - for i, f in enumerate(PushRuleTable.fields): - d[f] = r[i] - dicts.append(d) + rows.sort( + key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) + ) - defer.returnValue(dicts) + defer.returnValue(rows) @defer.inlineCallbacks def get_push_rules_enabled_for_user(self, user_name): results = yield self._simple_select_list( - PushRuleEnableTable.table_name, - {'user_name': user_name}, - PushRuleEnableTable.fields, + table=PushRuleEnableTable.table_name, + keyvalues={ + 'user_name': user_name + }, + retcols=PushRuleEnableTable.fields, desc="get_push_rules_enabled_for_user", ) - defer.returnValue( - {r['rule_id']: False if r['enabled'] == 0 else True for r in results} - ) + defer.returnValue({ + r['rule_id']: False if r['enabled'] == 0 else True for r in results + }) @defer.inlineCallbacks def add_push_rule(self, before, after, **kwargs): - vals = copy.copy(kwargs) + vals = kwargs if 'conditions' in vals: vals['conditions'] = json.dumps(vals['conditions']) if 'actions' in vals: vals['actions'] = json.dumps(vals['actions']) + # we could check the rest of the keys are valid column names # but sqlite will do that anyway so I think it's just pointless. - if 'id' in vals: - del vals['id'] + vals.pop("id", None) if before or after: ret = yield self.runInteraction( @@ -87,39 +85,39 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(ret) def _add_push_rule_relative_txn(self, txn, user_name, **kwargs): - after = None - relative_to_rule = None - if 'after' in kwargs and kwargs['after']: - after = kwargs['after'] - relative_to_rule = after - if 'before' in kwargs and kwargs['before']: - relative_to_rule = kwargs['before'] - - # get the priority of the rule we're inserting after/before - sql = ( - "SELECT priority_class, priority FROM ? " - "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,) + after = kwargs.pop("after", None) + relative_to_rule = kwargs.pop("before", after) + + res = self._simple_select_one_txn( + txn, + table=PushRuleTable.table_name, + keyvalues={ + "user_name": user_name, + "rule_id": relative_to_rule, + }, + retcols=["priority_class", "priority"], + allow_none=True, ) - txn.execute(sql, (user_name, relative_to_rule)) - res = txn.fetchall() + if not res: raise RuleNotFoundException( "before/after rule not found: %s" % (relative_to_rule,) ) - priority_class, base_rule_priority = res[0] + + priority_class = res["priority_class"] + base_rule_priority = res["priority"] if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: raise InconsistentRuleException( "Given priority class does not match class of relative rule" ) - new_rule = copy.copy(kwargs) - if 'before' in new_rule: - del new_rule['before'] - if 'after' in new_rule: - del new_rule['after'] + new_rule = kwargs + new_rule.pop("before", None) + new_rule.pop("after", None) new_rule['priority_class'] = priority_class new_rule['user_name'] = user_name + new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) # check if the priority before/after is free new_rule_priority = base_rule_priority @@ -153,12 +151,11 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) - # now insert the new rule - sql = "INSERT INTO "+PushRuleTable.table_name+" (" - sql += ",".join(new_rule.keys())+") VALUES (" - sql += ", ".join(["?" for _ in new_rule.keys()])+")" - - txn.execute(sql, new_rule.values()) + self._simple_insert_txn( + txn, + table=PushRuleTable.table_name, + values=new_rule, + ) def _add_push_rule_highest_priority_txn(self, txn, user_name, priority_class, **kwargs): @@ -176,18 +173,17 @@ class PushRuleStore(SQLBaseStore): new_prio = highest_prio + 1 # and insert the new rule - new_rule = copy.copy(kwargs) - if 'id' in new_rule: - del new_rule['id'] + new_rule = kwargs + new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) new_rule['user_name'] = user_name new_rule['priority_class'] = priority_class new_rule['priority'] = new_prio - sql = "INSERT INTO "+PushRuleTable.table_name+" (" - sql += ",".join(new_rule.keys())+") VALUES (" - sql += ", ".join(["?" for _ in new_rule.keys()])+")" - - txn.execute(sql, new_rule.values()) + self._simple_insert_txn( + txn, + table=PushRuleTable.table_name, + values=new_rule, + ) @defer.inlineCallbacks def delete_push_rule(self, user_name, rule_id): |