diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 9be1d12fac..f257721ea3 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream", "stream_id"
)
- self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
- self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
- self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
- self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
- self._pushers_id_gen = IdGenerator("pushers", "id", self)
- self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
- self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+ self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
+ self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+ self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
+ self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 91cbf399b6..faddefe219 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
- """Get all the client account_data for a that's changed.
+ def get_all_updated_account_data(self, last_global_id, last_room_id,
+ current_id, limit):
+ """Get all the client account_data that has changed on the server
+ Args:
+ last_global_id(int): The position to fetch from for top level data
+ last_room_id(int): The position to fetch from for per room data
+ current_id(int): The position to fetch up to.
+ Returns:
+ A deferred pair of lists of tuples of stream_id int, user_id string,
+ room_id string, type string, and content string.
+ """
+ def get_updated_account_data_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, account_data_type, content"
+ " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_global_id, current_id, limit))
+ global_results = txn.fetchall()
+
+ sql = (
+ "SELECT stream_id, user_id, room_id, account_data_type, content"
+ " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_room_id, current_id, limit))
+ room_results = txn.fetchall()
+ return (global_results, room_results)
+ return self.runInteraction(
+ "get_all_updated_account_data_txn", get_updated_account_data_txn
+ )
+
+ def get_updated_account_data_for_user(self, user_id, stream_id):
+ """Get all the client account_data for a that's changed for a user
Args:
user_id(str): The user to get the account_data for.
@@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore):
)
self._update_max_stream_id(txn, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id
)
- result = yield self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore):
)
self._update_max_stream_id(txn, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id
)
- result = yield self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 1dd3236829..60936500d8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
else:
- stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
- self, len(events_and_contexts)
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
@@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore):
stream_ordering = self.min_stream_token
if stream_ordering is None:
- stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+ stream_ordering_manager = self._stream_id_gen.get_next()
else:
@contextmanager
def stream_ordering_manager():
@@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore):
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
defer.returnValue(result)
+
+ def get_current_backfill_token(self):
+ """The current minimum token that backfilled events have reached"""
+
+ # TODO: Fix race with the persit_event txn by using one of the
+ # stream id managers
+ return -self.min_stream_token
+
+ def get_all_new_events(self, last_backfill_id, last_forward_id,
+ current_backfill_id, current_forward_id, limit):
+ """Get all the new events that have arrived at the server either as
+ new events or as backfilled events"""
+ def get_all_new_events_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, ej.internal_metadata, ej.json"
+ " FROM events as e"
+ " JOIN event_json as ej"
+ " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+ " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ " LIMIT ?"
+ )
+ if last_forward_id != current_forward_id:
+ txn.execute(sql, (last_forward_id, current_forward_id, limit))
+ new_forward_events = txn.fetchall()
+ else:
+ new_forward_events = []
+
+ sql = (
+ "SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
+ " FROM events as e"
+ " JOIN event_json as ej"
+ " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+ " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
+ " ORDER BY e.stream_ordering DESC"
+ " LIMIT ?"
+ )
+ if last_backfill_id != current_backfill_id:
+ txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+ new_backfill_events = txn.fetchall()
+ else:
+ new_backfill_events = []
+
+ return (new_forward_events, new_backfill_events)
+ return self.runInteraction("get_all_new_events", get_all_new_events_txn)
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 3ef91d34db..4cec31e316 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState",
class PresenceStore(SQLBaseStore):
@defer.inlineCallbacks
def update_presence(self, presence_states):
- stream_ordering_manager = yield self._presence_id_gen.get_next_mult(
- self, len(presence_states)
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ len(presence_states)
)
with stream_ordering_manager as stream_orderings:
@@ -115,6 +115,22 @@ class PresenceStore(SQLBaseStore):
args
)
+ def get_all_presence_updates(self, last_id, current_id):
+ def get_all_presence_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, state, last_active_ts,"
+ " last_federation_update_ts, last_user_sync_ts, status_msg,"
+ " currently_active"
+ " FROM presence_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ )
+ txn.execute(sql, (last_id, current_id))
+ return txn.fetchall()
+
+ return self.runInteraction(
+ "get_all_presence_updates", get_all_presence_updates_txn
+ )
+
@defer.inlineCallbacks
def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch(
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index bb5c14d912..56e69495b1 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore):
if txn.rowcount == 0:
# We didn't update a row with the given rule_id so insert one
- push_rule_id = self._push_rule_id_gen.get_next_txn(txn)
+ push_rule_id = self._push_rule_id_gen.get_next()
self._simple_insert_txn(
txn,
@@ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
- new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
+ new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn(
txn,
"push_rules_enable",
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index c23648cdbc..7693ab9082 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -84,7 +84,7 @@ class PusherStore(SQLBaseStore):
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data, profile_tag=""):
try:
- next_id = yield self._pushers_id_gen.get_next()
+ next_id = self._pushers_id_gen.get_next()
yield self._simple_upsert(
"pushers",
dict(
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index a7343c97f7..dbc074d6b5 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = yield self._receipts_id_gen.get_next(self)
+ stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
@@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
- max_persisted_id = yield self._stream_id_gen.get_max_token()
+ max_persisted_id = self._stream_id_gen.get_max_token()
defer.returnValue((stream_id, max_persisted_id))
@@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
+
+ def get_all_updated_receipts(self, last_id, current_id, limit):
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 03a9b66e4a..ad1157f979 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- next_id = yield self._access_tokens_id_gen.get_next()
+ next_id = self._access_tokens_id_gen.get_next()
yield self._simple_insert(
"access_tokens",
@@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- next_id = yield self._refresh_tokens_id_gen.get_next()
+ next_id = self._refresh_tokens_id_gen.get_next()
yield self._simple_insert(
"refresh_tokens",
@@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
now = int(self.clock.time())
- next_id = self._access_tokens_id_gen.get_next_txn(txn)
+ next_id = self._access_tokens_id_gen.get_next()
try:
if was_guest:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 372b540002..8ed8a21b0a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,7 +83,7 @@ class StateStore(SQLBaseStore):
if event.is_state():
state_events[(event.type, event.state_key)] = event
- state_group = self._state_groups_id_gen.get_next_txn(txn)
+ state_group = self._state_groups_id_gen.get_next()
self._simple_insert_txn(
txn,
table="state_groups",
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 9551aa9739..a0e6b42b30 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -59,6 +59,59 @@ class TagsStore(SQLBaseStore):
return deferred
@defer.inlineCallbacks
+ def get_all_updated_tags(self, last_id, current_id, limit):
+ """Get all the client tags that have changed on the server
+ Args:
+ last_id(int): The position to fetch from.
+ current_id(int): The position to fetch up to.
+ Returns:
+ A deferred list of tuples of stream_id int, user_id string,
+ room_id string, tag string and content string.
+ """
+ def get_all_updated_tags_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, room_id"
+ " FROM room_tags_revisions as r"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ tag_ids = yield self.runInteraction(
+ "get_all_updated_tags", get_all_updated_tags_txn
+ )
+
+ def get_tag_content(txn, tag_ids):
+ sql = (
+ "SELECT tag, content"
+ " FROM room_tags"
+ " WHERE user_id=? AND room_id=?"
+ )
+ results = []
+ for stream_id, user_id, room_id in tag_ids:
+ txn.execute(sql, (user_id, room_id))
+ tags = []
+ for tag, content in txn.fetchall():
+ tags.append(json.dumps(tag) + ":" + content)
+ tag_json = "{" + ",".join(tags) + "}"
+ results.append((stream_id, user_id, room_id, tag_json))
+
+ return results
+
+ batch_size = 50
+ results = []
+ for i in xrange(0, len(tag_ids), batch_size):
+ tags = yield self.runInteraction(
+ "get_all_updated_tag_content",
+ get_tag_content,
+ tag_ids[i:i + batch_size],
+ )
+ results.extend(tags)
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
"""Get all the tags for the rooms where the tags have changed since the
given version
@@ -142,12 +195,12 @@ class TagsStore(SQLBaseStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -164,12 +217,12 @@ class TagsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 4475c451c1..d338dfcf0a 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts):
- next_id = self._transaction_id_gen.get_next_txn(txn)
+ next_id = self._transaction_id_gen.get_next()
# First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ef5e4a4668..efe3f68e6e 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,51 +13,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from collections import deque
import contextlib
import threading
class IdGenerator(object):
- def __init__(self, table, column, store):
+ def __init__(self, db_conn, table, column):
self.table = table
self.column = column
- self.store = store
self._lock = threading.Lock()
- self._next_id = None
+ cur = db_conn.cursor()
+ self._next_id = self._load_next_id(cur)
+ cur.close()
- @defer.inlineCallbacks
- def get_next(self):
- if self._next_id is None:
- yield self.store.runInteraction(
- "IdGenerator_%s" % (self.table,),
- self.get_next_txn,
- )
+ def _load_next_id(self, txn):
+ txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
+ val, = txn.fetchone()
+ return val + 1 if val else 1
+ def get_next(self):
with self._lock:
i = self._next_id
self._next_id += 1
- defer.returnValue(i)
-
- def get_next_txn(self, txn):
- with self._lock:
- if self._next_id:
- i = self._next_id
- self._next_id += 1
- return i
- else:
- txn.execute(
- "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
- )
-
- val, = txn.fetchone()
- cur = val or 0
- cur += 1
- self._next_id = cur + 1
-
- return cur
+ return i
class StreamIdGenerator(object):
@@ -69,7 +48,7 @@ class StreamIdGenerator(object):
persistence of events can complete out of order.
Usage:
- with stream_id_gen.get_next_txn(txn) as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(self, db_conn, table, column):
@@ -79,15 +58,21 @@ class StreamIdGenerator(object):
self._lock = threading.Lock()
cur = db_conn.cursor()
- self._current_max = self._get_or_compute_current_max(cur)
+ self._current_max = self._load_current_max(cur)
cur.close()
self._unfinished_ids = deque()
- def get_next(self, store):
+ def _load_current_max(self, txn):
+ txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
+ rows = txn.fetchall()
+ val, = rows[0]
+ return int(val) if val else 1
+
+ def get_next(self):
"""
Usage:
- with yield stream_id_gen.get_next as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -106,10 +91,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, store, n):
+ def get_next_mult(self, n):
"""
Usage:
- with yield stream_id_gen.get_next(store, n) as stream_ids:
+ with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -139,13 +124,3 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1
return self._current_max
-
- def _get_or_compute_current_max(self, txn):
- with self._lock:
- txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
- rows = txn.fetchall()
- val, = rows[0]
-
- self._current_max = int(val) if val else 1
-
- return self._current_max
|