diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f257721ea3..e2d7b52569 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -45,7 +45,7 @@ from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
-from util.id_generators import IdGenerator, StreamIdGenerator
+from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -122,6 +122,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+ self._push_rules_stream_id_gen = ChainedIdGenerator(
+ self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+ )
events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
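The new `_push_rules_stream_id_gen` chains the push-rules stream to the event stream, so each push rule change is stamped with the event-stream position that was current when it happened. Below is a minimal harness sketching the wiring, assuming the patched tree is importable; the throwaway sqlite schema and variable names are illustrative, not Synapse's real schema:

```python
import sqlite3

from synapse.storage.util.id_generators import (
    ChainedIdGenerator,
    StreamIdGenerator,
)

# Just enough schema for the MAX() queries the generators run at startup.
db_conn = sqlite3.connect(":memory:")
db_conn.execute("CREATE TABLE events (stream_ordering BIGINT)")
db_conn.execute("CREATE TABLE push_rules_stream (stream_id BIGINT)")

events_id_gen = StreamIdGenerator(db_conn, "events", "stream_ordering")
push_rules_stream_id_gen = ChainedIdGenerator(
    events_id_gen, db_conn, "push_rules_stream", "stream_id"
)

# Each allocation pairs a fresh push-rules stream id with the event
# stream position it is chained to.
with push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
    # On an empty database the first allocation is stream_id 2, because
    # _load_max_id treats an empty table as max 1, chained to event
    # stream position 1.
    assert (stream_id, stream_ordering) == (2, 1)
```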
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2e97ac84a8..7dc67ecd57 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -770,18 +770,29 @@ class SQLBaseStore(object):
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
+ return self.runInteraction(
+ desc, self._simple_delete_one_txn, table, keyvalues
+ )
+
+ @staticmethod
+ def _simple_delete_one_txn(txn, table, keyvalues):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- def func(txn):
- txn.execute(sql, keyvalues.values())
- if txn.rowcount == 0:
- raise StoreError(404, "No row found")
- if txn.rowcount > 1:
- raise StoreError(500, "more than one row matched")
- return self.runInteraction(desc, func)
+ txn.execute(sql, keyvalues.values())
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "more than one row matched")
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
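Splitting `_simple_delete_one` into a deferred-returning wrapper plus a static `_simple_delete_one_txn` lets callers that are already inside a transaction (such as `delete_push_rule_txn` later in this patch) compose the delete with further writes atomically, instead of scheduling a second interaction. A standalone sketch of the helper's behaviour against plain sqlite3; the `StoreError` stand-in and the `list(...)` around `values()` are mine:

```python
import sqlite3


class StoreError(Exception):
    """Stand-in for synapse.api.errors.StoreError."""
    def __init__(self, code, msg):
        super(StoreError, self).__init__("%d: %s" % (code, msg))


def _simple_delete_one_txn(txn, table, keyvalues):
    # Same shape as the patched helper: build the WHERE clause from the
    # dict keys, then insist on exactly one row having matched.
    sql = "DELETE FROM %s WHERE %s" % (
        table,
        " AND ".join("%s = ?" % (k,) for k in keyvalues),
    )
    txn.execute(sql, list(keyvalues.values()))
    if txn.rowcount == 0:
        raise StoreError(404, "No row found")
    if txn.rowcount > 1:
        raise StoreError(500, "more than one row matched")


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE push_rules (user_name TEXT, rule_id TEXT)")
conn.execute("INSERT INTO push_rules VALUES ('@u:example.com', 'r1')")

txn = conn.cursor()
_simple_delete_one_txn(
    txn, "push_rules", {"user_name": "@u:example.com", "rule_id": "r1"}
)
conn.commit()

try:
    _simple_delete_one_txn(
        txn, "push_rules", {"user_name": "@u:example.com", "rule_id": "r1"}
    )
except StoreError as e:
    print(e)  # 404: No row found, since the row is already gone
```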
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 56e69495b1..f3ebd49492 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -99,30 +99,31 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results)
+ @defer.inlineCallbacks
def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions,
before=None, after=None
):
conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions)
-
- if before or after:
- return self.runInteraction(
- "_add_push_rule_relative_txn",
- self._add_push_rule_relative_txn,
- user_id, rule_id, priority_class,
- conditions_json, actions_json, before, after,
- )
- else:
- return self.runInteraction(
- "_add_push_rule_highest_priority_txn",
- self._add_push_rule_highest_priority_txn,
- user_id, rule_id, priority_class,
- conditions_json, actions_json,
- )
+ with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
+ if before or after:
+ yield self.runInteraction(
+ "_add_push_rule_relative_txn",
+ self._add_push_rule_relative_txn,
+ stream_id, stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json, before, after,
+ )
+ else:
+ yield self.runInteraction(
+ "_add_push_rule_highest_priority_txn",
+ self._add_push_rule_highest_priority_txn,
+ stream_id, stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json,
+ )
def _add_push_rule_relative_txn(
- self, txn, user_id, rule_id, priority_class,
+ self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after
):
# Lock the table since otherwise we'll have annoying races between the
@@ -174,12 +175,12 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn(
- txn, user_id, rule_id, priority_class, new_rule_priority,
- conditions_json, actions_json,
+ txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
+ new_rule_priority, conditions_json, actions_json,
)
def _add_push_rule_highest_priority_txn(
- self, txn, user_id, rule_id, priority_class,
+ self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json
):
# Lock the table since otherwise we'll have annoying races between the
@@ -201,13 +202,13 @@ class PushRuleStore(SQLBaseStore):
self._upsert_push_rule_txn(
txn,
- user_id, rule_id, priority_class, new_prio,
+ stream_id, stream_ordering, user_id, rule_id, priority_class, new_prio,
conditions_json, actions_json,
)
def _upsert_push_rule_txn(
- self, txn, user_id, rule_id, priority_class,
- priority, conditions_json, actions_json
+ self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class,
+ priority, conditions_json, actions_json, update_stream=True
):
"""Specialised version of _simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
@@ -242,6 +243,23 @@ class PushRuleStore(SQLBaseStore):
},
)
+ if update_stream:
+ self._simple_insert_txn(
+ txn,
+ table="push_rules_stream",
+ values={
+ "stream_id": stream_id,
+ "stream_ordering": stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": "ADD",
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ }
+ )
+
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
@@ -260,25 +278,47 @@ class PushRuleStore(SQLBaseStore):
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
- yield self._simple_delete_one(
- "push_rules",
- {'user_name': user_id, 'rule_id': rule_id},
- desc="delete_push_rule",
- )
+ def delete_push_rule_txn(txn, stream_id, stream_ordering):
+ self._simple_delete_one_txn(
+ txn,
+ "push_rules",
+ {'user_name': user_id, 'rule_id': rule_id},
+ )
+ self._simple_insert_txn(
+ txn,
+ table="push_rules_stream",
+ values={
+ "stream_id": stream_id,
+ "stream_ordering": stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": "DELETE",
+ }
+ )
+ txn.call_after(
+ self.get_push_rules_for_user.invalidate, (user_id,)
+ )
+ txn.call_after(
+ self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+ )
- self.get_push_rules_for_user.invalidate((user_id,))
- self.get_push_rules_enabled_for_user.invalidate((user_id,))
+ with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
+ yield self.runInteraction(
+ "delete_push_rule", delete_push_rule_txn, stream_id, stream_ordering
+ )
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled):
- ret = yield self.runInteraction(
- "_set_push_rule_enabled_txn",
- self._set_push_rule_enabled_txn,
- user_id, rule_id, enabled
- )
- defer.returnValue(ret)
+ with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
+ yield self.runInteraction(
+ "_set_push_rule_enabled_txn",
+ self._set_push_rule_enabled_txn,
+ stream_id, stream_ordering, user_id, rule_id, enabled
+ )
- def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
+ def _set_push_rule_enabled_txn(
+ self, txn, stream_id, stream_ordering, user_id, rule_id, enabled
+ ):
new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn(
txn,
@@ -287,6 +327,19 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
+
+ self._simple_insert_txn(
+ txn,
+ "push_rules_stream",
+ values={
+ "stream_id": stream_id,
+ "stream_ordering": stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": "ENABLE" if enabled else "DISABLE",
+ }
+ )
+
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
@@ -294,18 +347,20 @@ class PushRuleStore(SQLBaseStore):
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
+ @defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
actions_json = json.dumps(actions)
- def set_push_rule_actions_txn(txn):
+ def set_push_rule_actions_txn(txn, stream_id, stream_ordering):
if is_default_rule:
# Add a dummy rule to the rules table with the user specified
# actions.
priority_class = -1
priority = 1
self._upsert_push_rule_txn(
- txn, user_id, rule_id, priority_class, priority,
- "[]", actions_json
+ txn, stream_id, stream_ordering, user_id, rule_id,
+ priority_class, priority, "[]", actions_json,
+ update_stream=False
)
else:
self._simple_update_one_txn(
@@ -315,8 +370,46 @@ class PushRuleStore(SQLBaseStore):
{'actions': actions_json},
)
+ self._simple_insert_txn(
+ txn,
+ "push_rules_stream",
+ values={
+ "stream_id": stream_id,
+ "stream_ordering": stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": "ACTIONS",
+ "actions": actions_json,
+ }
+ )
+
+ txn.call_after(
+ self.get_push_rules_for_user.invalidate, (user_id,)
+ )
+ txn.call_after(
+ self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering):
+ yield self.runInteraction(
+ "set_push_rule_actions", set_push_rule_actions_txn,
+ stream_id, stream_ordering
+ )
+
+ def get_all_push_rule_updates(self, last_id, current_id, limit):
+        """Get all the push rule changes that have happened on the server"""
+ def get_all_push_rule_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, stream_ordering, user_id, rule_id,"
+ " op, priority_class, priority, conditions, actions"
+ " FROM push_rules_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
return self.runInteraction(
- "set_push_rule_actions", set_push_rule_actions_txn,
+ "get_all_push_rule_updates", get_all_push_rule_updates_txn
)
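`get_all_push_rule_updates` is shaped for replication catch-up: a half-open window `(last_id, current_id]` plus a `LIMIT`, with rows ordered by `stream_id`. A hypothetical consumer loop paging through it; `store` is assumed to be a `DataStore` with this patch applied, and `catch_up_push_rules` is an illustrative name, not part of the patch:

```python
from twisted.internet import defer


@defer.inlineCallbacks
def catch_up_push_rules(store, last_id, current_id, batch_size=100):
    """Page through push rule updates in stream order up to current_id."""
    while last_id < current_id:
        rows = yield store.get_all_push_rule_updates(
            last_id, current_id, batch_size
        )
        if not rows:
            break
        # Each row is (stream_id, stream_ordering, user_id, rule_id, op,
        # priority_class, priority, conditions, actions); apply `op` here.
        last_id = rows[-1][0]  # rows are ordered by stream_id ASC
    defer.returnValue(last_id)
```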
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/schema/delta/30/push_rule_stream.sql
new file mode 100644
index 0000000000..e8418bb35f
--- /dev/null
+++ b/synapse/storage/schema/delta/30/push_rule_stream.sql
@@ -0,0 +1,38 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+
+CREATE TABLE push_rules_stream(
+ stream_id BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL,
+ user_id TEXT NOT NULL,
+ rule_id TEXT NOT NULL,
+ op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE"
+ priority_class SMALLINT,
+ priority INTEGER,
+ conditions TEXT,
+ actions TEXT
+);
+
+-- The extra data for each operation is:
+-- * ENABLE, DISABLE, DELETE: []
+-- * ACTIONS: ["actions"]
+-- * ADD: ["priority_class", "priority", "actions", "conditions"]
+
+-- Index for replication queries.
+CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id);
+-- Index for /sync queries.
+CREATE INDEX push_rules_stream_user_stream_id ON push_rules_stream(user_id, stream_id);
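The two indexes match the two read paths noted in the comments above. A small sqlite3 sketch of those query shapes against this table; the replication window mirrors `get_all_push_rule_updates`, while the /sync-style "changed since" probe is a guess at how the `(user_id, stream_id)` index would be used, not code from this patch:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE push_rules_stream(
        stream_id BIGINT NOT NULL,
        stream_ordering BIGINT NOT NULL,
        user_id TEXT NOT NULL,
        rule_id TEXT NOT NULL,
        op TEXT NOT NULL,
        priority_class SMALLINT,
        priority INTEGER,
        conditions TEXT,
        actions TEXT
    )
""")
conn.execute(
    "INSERT INTO push_rules_stream VALUES"
    " (1, 10, '@u:example.com', 'r1', 'ENABLE', NULL, NULL, NULL, NULL)"
)

# Replication: everything in a stream_id window (push_rules_stream_id).
rows = conn.execute(
    "SELECT * FROM push_rules_stream"
    " WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC",
    (0, 1),
).fetchall()
assert len(rows) == 1

# /sync: has this user's rule set changed since their last token?
# (push_rules_stream_user_stream_id)
changed = conn.execute(
    "SELECT 1 FROM push_rules_stream"
    " WHERE user_id = ? AND stream_id > ? LIMIT 1",
    ("@u:example.com", 0),
).fetchone() is not None
assert changed
```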
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index efe3f68e6e..af425ba9a4 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,23 +20,21 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
- self.table = table
- self.column = column
self._lock = threading.Lock()
- cur = db_conn.cursor()
- self._next_id = self._load_next_id(cur)
- cur.close()
-
- def _load_next_id(self, txn):
- txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
- val, = txn.fetchone()
- return val + 1 if val else 1
+ self._next_id = _load_max_id(db_conn, table, column)
def get_next(self):
with self._lock:
- i = self._next_id
self._next_id += 1
- return i
+ return self._next_id
+
+
+def _load_max_id(db_conn, table, column):
+ cur = db_conn.cursor()
+ cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ val, = cur.fetchone()
+ cur.close()
+    return int(val) if val else 1
class StreamIdGenerator(object):
@@ -52,23 +50,10 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
def __init__(self, db_conn, table, column):
- self.table = table
- self.column = column
-
self._lock = threading.Lock()
-
- cur = db_conn.cursor()
- self._current_max = self._load_current_max(cur)
- cur.close()
-
+ self._current_max = _load_max_id(db_conn, table, column)
self._unfinished_ids = deque()
- def _load_current_max(self, txn):
- txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
- rows = txn.fetchall()
- val, = rows[0]
- return int(val) if val else 1
-
def get_next(self):
"""
Usage:
@@ -124,3 +109,50 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1
return self._current_max
+
+
+class ChainedIdGenerator(object):
+ """Used to generate new stream ids where the stream must be kept in sync
+    with another stream. It generates pairs of IDs: the first element is a
+    fresh ID for this stream, and the second is the current ID of the stream
+    that it must be kept in sync with."""
+
+ def __init__(self, chained_generator, db_conn, table, column):
+ self.chained_generator = chained_generator
+ self._lock = threading.Lock()
+ self._current_max = _load_max_id(db_conn, table, column)
+ self._unfinished_ids = deque()
+
+ def get_next(self):
+ """
+ Usage:
+ with stream_id_gen.get_next() as (stream_id, chained_id):
+ # ... persist event ...
+ """
+ with self._lock:
+ self._current_max += 1
+ next_id = self._current_max
+ chained_id = self.chained_generator.get_max_token()
+
+ self._unfinished_ids.append((next_id, chained_id))
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield (next_id, chained_id)
+ finally:
+ with self._lock:
+ self._unfinished_ids.remove((next_id, chained_id))
+
+ return manager()
+
+ def get_max_token(self):
+        """Returns (stream_id, chained_id): the maximum stream id such that
+        all ids up to and including it have been persisted, with the
+        chained stream's position at that point."""
+ with self._lock:
+ if self._unfinished_ids:
+ stream_id, chained_id = self._unfinished_ids[0]
+ return (stream_id - 1, chained_id)
+
+ return (self._current_max, self.chained_generator.get_max_token())
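A toy check of the ordering guarantee, assuming the patched module is importable: `get_max_token` must stay behind any allocation whose `with` block has not exited, since that change may not have been persisted yet. `FakeEventStream` here is a hypothetical stand-in for the event stream's `StreamIdGenerator`:

```python
import sqlite3

from synapse.storage.util.id_generators import ChainedIdGenerator


class FakeEventStream(object):
    """Hypothetical stand-in for the chained (event) stream."""
    def get_max_token(self):
        return 100


db_conn = sqlite3.connect(":memory:")
db_conn.execute("CREATE TABLE push_rules_stream (stream_id BIGINT)")

gen = ChainedIdGenerator(
    FakeEventStream(), db_conn, "push_rules_stream", "stream_id"
)

with gen.get_next() as (stream_id, chained_id):
    # While the allocation is in flight, the advertised token stays
    # strictly behind it.
    assert gen.get_max_token() == (stream_id - 1, chained_id)

# Once the block exits (i.e. the change has been persisted), the token
# may advance to cover it.
assert gen.get_max_token() == (stream_id, 100)
```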