diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 35ec7e8cef..9dbad2fd5f 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -65,32 +65,20 @@ class PushRuleStore(SQLBaseStore):
if not user_ids:
defer.returnValue({})
- batch_size = 100
-
- def f(txn, user_ids_to_fetch):
- sql = (
- "SELECT pr.*"
- " FROM push_rules AS pr"
- " LEFT JOIN push_rules_enable AS pre"
- " ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id"
- " WHERE pr.user_name"
- " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
- " AND (pre.enabled IS NULL OR pre.enabled = 1)"
- " ORDER BY pr.user_name, pr.priority_class DESC, pr.priority DESC"
- )
- txn.execute(sql, user_ids_to_fetch)
- return self.cursor_to_dict(txn)
-
results = {}
- chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
- for batch_user_ids in chunks:
- rows = yield self.runInteraction(
- "bulk_get_push_rules", f, batch_user_ids
- )
+ rows = yield self._simple_select_many_batch(
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("*",),
+ desc="bulk_get_push_rules",
+ )
+
+ rows.sort(key=lambda e: (-e["priority_class"], -e["priority"]))
- for row in rows:
- results.setdefault(row['user_name'], []).append(row)
+ for row in rows:
+ results.setdefault(row['user_name'], []).append(row)
defer.returnValue(results)
@defer.inlineCallbacks
@@ -98,62 +86,52 @@ class PushRuleStore(SQLBaseStore):
if not user_ids:
defer.returnValue({})
- batch_size = 100
-
- def f(txn, user_ids_to_fetch):
- sql = (
- "SELECT user_name, rule_id, enabled"
- " FROM push_rules_enable"
- " WHERE user_name"
- " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
- )
- txn.execute(sql, user_ids_to_fetch)
- return self.cursor_to_dict(txn)
-
results = {}
- chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
- for batch_user_ids in chunks:
- rows = yield self.runInteraction(
- "bulk_get_push_rules_enabled", f, batch_user_ids
- )
-
- for row in rows:
- results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+ rows = yield self._simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled",),
+ desc="bulk_get_push_rules_enabled",
+ )
+ for row in rows:
+ results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results)
@defer.inlineCallbacks
- def add_push_rule(self, before, after, **kwargs):
- vals = kwargs
- if 'conditions' in vals:
- vals['conditions'] = json.dumps(vals['conditions'])
- if 'actions' in vals:
- vals['actions'] = json.dumps(vals['actions'])
-
- # we could check the rest of the keys are valid column names
- # but sqlite will do that anyway so I think it's just pointless.
- vals.pop("id", None)
-
- if before or after:
- ret = yield self.runInteraction(
- "_add_push_rule_relative_txn",
- self._add_push_rule_relative_txn,
- before=before,
- after=after,
- **vals
- )
- defer.returnValue(ret)
- else:
- ret = yield self.runInteraction(
- "_add_push_rule_highest_priority_txn",
- self._add_push_rule_highest_priority_txn,
- **vals
- )
- defer.returnValue(ret)
-
- def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
- after = kwargs.pop("after", None)
- relative_to_rule = kwargs.pop("before", after)
+ def add_push_rule(
+ self, user_id, rule_id, priority_class, conditions, actions,
+ before=None, after=None
+ ):
+ conditions_json = json.dumps(conditions)
+ actions_json = json.dumps(actions)
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ if before or after:
+ yield self.runInteraction(
+ "_add_push_rule_relative_txn",
+ self._add_push_rule_relative_txn,
+ stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json, before, after,
+ )
+ else:
+ yield self.runInteraction(
+ "_add_push_rule_highest_priority_txn",
+ self._add_push_rule_highest_priority_txn,
+ stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json,
+ )
+
+ def _add_push_rule_relative_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json, before, after
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
+
+ relative_to_rule = before or after
res = self._simple_select_one_txn(
txn,
@@ -171,69 +149,45 @@ class PushRuleStore(SQLBaseStore):
"before/after rule not found: %s" % (relative_to_rule,)
)
- priority_class = res["priority_class"]
+ base_priority_class = res["priority_class"]
base_rule_priority = res["priority"]
- if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
+ if base_priority_class != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
- new_rule = kwargs
- new_rule.pop("before", None)
- new_rule.pop("after", None)
- new_rule['priority_class'] = priority_class
- new_rule['user_name'] = user_id
- new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
-
- # check if the priority before/after is free
- new_rule_priority = base_rule_priority
- if after:
- new_rule_priority -= 1
+ if before:
+ # Higher priority rules are executed first, So adding a rule before
+ # a rule means giving it a higher priority than that rule.
+ new_rule_priority = base_rule_priority + 1
else:
- new_rule_priority += 1
-
- new_rule['priority'] = new_rule_priority
+ # We increment the priority of the existing rules to make space for
+ # the new rule. Therefore if we want this rule to appear after
+ # an existing rule we give it the priority of the existing rule,
+ # and then increment the priority of the existing rule.
+ new_rule_priority = base_rule_priority
sql = (
- "SELECT COUNT(*) FROM push_rules"
- " WHERE user_name = ? AND priority_class = ? AND priority = ?"
+ "UPDATE push_rules SET priority = priority + 1"
+ " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
)
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
- res = txn.fetchall()
- num_conflicting = res[0][0]
-
- # if there are conflicting rules, bump everything
- if num_conflicting:
- sql = "UPDATE push_rules SET priority = priority "
- if after:
- sql += "-1"
- else:
- sql += "+1"
- sql += " WHERE user_name = ? AND priority_class = ? AND priority "
- if after:
- sql += "<= ?"
- else:
- sql += ">= ?"
-
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
- txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_id,)
- )
+ txn.execute(sql, (user_id, priority_class, new_rule_priority))
- txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+ self._upsert_push_rule_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ new_rule_priority, conditions_json, actions_json,
)
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values=new_rule,
- )
+ def _add_push_rule_highest_priority_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
- def _add_push_rule_highest_priority_txn(self, txn, user_id,
- priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM push_rules"
@@ -247,26 +201,61 @@ class PushRuleStore(SQLBaseStore):
if how_many > 0:
new_prio = highest_prio + 1
- # and insert the new rule
- new_rule = kwargs
- new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
- new_rule['user_name'] = user_id
- new_rule['priority_class'] = priority_class
- new_rule['priority'] = new_prio
-
- txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_id,)
- )
- txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
+ conditions_json, actions_json,
)
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values=new_rule,
+ def _upsert_push_rule_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ priority, conditions_json, actions_json, update_stream=True
+ ):
+ """Specialised version of _simple_upsert_txn that picks a push_rule_id
+ using the _push_rule_id_gen if it needs to insert the rule. It assumes
+ that the "push_rules" table is locked"""
+
+ sql = (
+ "UPDATE push_rules"
+ " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+ " WHERE user_name = ? AND rule_id = ?"
)
+ txn.execute(sql, (
+ priority_class, priority, conditions_json, actions_json,
+ user_id, rule_id,
+ ))
+
+ if txn.rowcount == 0:
+ # We didn't update a row with the given rule_id so insert one
+ push_rule_id = self._push_rule_id_gen.get_next()
+
+ self._simple_insert_txn(
+ txn,
+ table="push_rules",
+ values={
+ "id": push_rule_id,
+ "user_name": user_id,
+ "rule_id": rule_id,
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ },
+ )
+
+ if update_stream:
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="ADD",
+ data={
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ }
+ )
+
@defer.inlineCallbacks
def delete_push_rule(self, user_id, rule_id):
"""
@@ -278,26 +267,38 @@ class PushRuleStore(SQLBaseStore):
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
- yield self._simple_delete_one(
- "push_rules",
- {'user_name': user_id, 'rule_id': rule_id},
- desc="delete_push_rule",
- )
+ def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ self._simple_delete_one_txn(
+ txn,
+ "push_rules",
+ {'user_name': user_id, 'rule_id': rule_id},
+ )
- self.get_push_rules_for_user.invalidate((user_id,))
- self.get_push_rules_enabled_for_user.invalidate((user_id,))
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="DELETE"
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.runInteraction(
+ "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
+ )
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled):
- ret = yield self.runInteraction(
- "_set_push_rule_enabled_txn",
- self._set_push_rule_enabled_txn,
- user_id, rule_id, enabled
- )
- defer.returnValue(ret)
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.runInteraction(
+ "_set_push_rule_enabled_txn",
+ self._set_push_rule_enabled_txn,
+ stream_id, event_stream_ordering, user_id, rule_id, enabled
+ )
- def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
- new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
+ def _set_push_rule_enabled_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ ):
+ new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn(
txn,
"push_rules_enable",
@@ -305,12 +306,109 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
+
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="ENABLE" if enabled else "DISABLE"
+ )
+
+ @defer.inlineCallbacks
+ def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ actions_json = json.dumps(actions)
+
+ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+ if is_default_rule:
+ # Add a dummy rule to the rules table with the user specified
+ # actions.
+ priority_class = -1
+ priority = 1
+ self._upsert_push_rule_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ priority_class, priority, "[]", actions_json,
+ update_stream=False
+ )
+ else:
+ self._simple_update_one_txn(
+ txn,
+ "push_rules",
+ {'user_name': user_id, 'rule_id': rule_id},
+ {'actions': actions_json},
+ )
+
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="ACTIONS", data={"actions": actions_json}
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.runInteraction(
+ "set_push_rule_actions", set_push_rule_actions_txn,
+ stream_id, event_stream_ordering
+ )
+
+ def _insert_push_rules_update_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
+ ):
+ values = {
+ "stream_id": stream_id,
+ "event_stream_ordering": event_stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": op,
+ }
+ if data is not None:
+ values.update(data)
+
+ self._simple_insert_txn(txn, "push_rules_stream", values=values)
+
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
+ txn.call_after(
+ self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
+ )
+
+ def get_all_push_rule_updates(self, last_id, current_id, limit):
+ """Get all the push rules changes that have happend on the server"""
+ def get_all_push_rule_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+ " op, priority_class, priority, conditions, actions"
+ " FROM push_rules_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_push_rule_updates", get_all_push_rule_updates_txn
+ )
+
+ def get_push_rules_stream_token(self):
+ """Get the position of the push rules stream.
+ Returns a pair of a stream id for the push_rules stream and the
+ room stream ordering it corresponds to."""
+ return self._push_rules_stream_id_gen.get_max_token()
+
+ def have_push_rules_changed_for_user(self, user_id, last_id):
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ count, = txn.fetchone()
+ return bool(count)
+ return self.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
class RuleNotFoundException(Exception):
|