diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index aaad38039e..f87e907cd8 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore,
self.hs = hs
self.database_engine = hs.database_engine
- cur = db_conn.cursor()
- try:
- cur.execute("SELECT MIN(stream_ordering) FROM events",)
- rows = cur.fetchall()
- self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
- self.min_stream_token = min(self.min_stream_token, -1)
- finally:
- cur.close()
-
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
@@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering"
)
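+ # Backfilled events take stream orderings counting down from the current
+ # minimum, so this generator grows in the -1 direction.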
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering", direction=-1
+ )
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
@@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")],
)
- events_max = self._stream_id_gen.get_max_token()
+ events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
@@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
- account_max = self._account_data_id_gen.get_max_token()
+ account_max = self._account_data_id_gen.get_current_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
@@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream",
entity_column="user_id",
stream_column="stream_id",
- max_value=self._presence_id_gen.get_max_token(),
+ max_value=self._presence_id_gen.get_current_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
@@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
- max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+ max_value=self._push_rules_stream_id_gen.get_current_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index faddefe219..7a7fbf1e52 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
"add_room_account_data", add_account_data_txn, next_id
)
- result = self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
"add_user_account_data", add_account_data_txn, next_id
)
- result = self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 83279d65fa..4ab23c1597 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -24,7 +24,6 @@ from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from canonicaljson import encode_canonical_json
-from contextlib import contextmanager
from collections import namedtuple
import logging
@@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore):
return
if backfilled:
- start = self.min_stream_token - 1
- self.min_stream_token -= len(events_and_contexts) + 1
- stream_orderings = range(start, self.min_stream_token, -1)
-
- @contextmanager
- def stream_ordering_manager():
- yield stream_orderings
- stream_ordering_manager = stream_ordering_manager()
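+ # hand out stream orderings from the backwards-growing backfill generator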
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ len(events_and_contexts)
+ )
else:
stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
@@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore):
except _RollbackButIsFineException:
pass
- max_persisted_id = yield self._stream_id_gen.get_max_token()
+ max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks
@@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore):
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
-
- # TODO: Fix race with the persit_event txn by using one of the
- # stream id managers
- return -self.min_stream_token
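+ # backfill stream orderings are negative, so negate to expose a positive token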
+ return -self._backfill_id_gen.get_current_token()
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 4cec31e316..59b4ef5ce6 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
self._update_presence_txn, stream_orderings, presence_states,
)
- defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
+ defer.returnValue((
+ stream_orderings[-1], self._presence_id_gen.get_current_token()
+ ))
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
@@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self):
- return self._presence_id_gen.get_max_token()
+ return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 9dbad2fd5f..d2bf7f2aec 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_max_token()
+ return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 87b2ac5773..d1669c778a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows)
def get_pushers_stream_token(self):
- return self._pushers_id_gen.get_max_token()
+ return self._pushers_id_gen.get_current_token()
def get_all_updated_pushers(self, last_id, current_id, limit):
def get_all_updated_pushers_txn(txn):
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 6b9d848eaa..4befebc8e2 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
+ "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
@cached(num_args=2)
@@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results)
def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_max_token()
+ return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
@@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
- max_persisted_id = self._stream_id_gen.get_max_token()
+ max_persisted_id = self._stream_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id))
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7fc9a4f264..8644830657 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -458,4 +458,4 @@ class StateStore(SQLBaseStore):
)
def get_state_stream_token(self):
- return self._state_groups_id_gen.get_max_token()
+ return self._state_groups_id_gen.get_current_token()
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index cf84938be5..76bcd9cd00 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
- token = yield self._stream_id_gen.get_max_token()
+ token = yield self._stream_id_gen.get_current_token()
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index a0e6b42b30..9da23f34cb 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns:
A deferred int.
"""
- return self._account_data_id_gen.get_max_token()
+ return self._account_data_id_gen.get_current_token()
@cached()
def get_tags_for_user(self, user_id):
@@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_max_token()
+ result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index a02dfc7d58..03f2aa6a5c 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,7 +21,7 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
- self._next_id = _load_max_id(db_conn, table, column)
+ self._next_id = _load_current_id(db_conn, table, column)
def get_next(self):
with self._lock:
@@ -29,12 +29,16 @@ class IdGenerator(object):
return self._next_id
-def _load_max_id(db_conn, table, column):
+def _load_current_id(db_conn, table, column, direction=1):
cur = db_conn.cursor()
- cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ if direction == 1:
+ cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ else:
+ cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
- return int(val) if val else 1
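+ # an empty table falls back to the direction value (1 or -1); the clamp keeps
+ # the starting id at or beyond that value in the stream's direction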
+ current_id = int(val) if val else direction
+ return (max if direction == 1 else min)(current_id, direction)
class StreamIdGenerator(object):
@@ -45,17 +49,30 @@ class StreamIdGenerator(object):
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
+ :param connection db_conn: A database connection to use to fetch the
+ initial value of the generator from.
+ :param str table: A database table to read the initial value of the id
+ generator from.
+ :param str column: The column of the database table to read the initial
+ value of the id generator from.
+ :param list extra_tables: List of pairs of database tables and columns to
+ use to source the initial value of the generator from. The value
+ furthest along in the stream's direction is used.
+ :param int direction: which direction the stream ids grow in. +1 to grow
+ upwards, -1 to grow downwards.
+
Usage:
with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- def __init__(self, db_conn, table, column, extra_tables=[]):
+ def __init__(self, db_conn, table, column, extra_tables=[], direction=1):
self._lock = threading.Lock()
- self._current_max = _load_max_id(db_conn, table, column)
+ self._direction = direction
+ self._current = _load_current_id(db_conn, table, column, direction)
for table, column in extra_tables:
- self._current_max = max(
- self._current_max,
- _load_max_id(db_conn, table, column)
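+ # take whichever table is furthest along in the stream's direction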
+ self._current = (max if direction > 0 else min)(
+ self._current,
+ _load_current_id(db_conn, table, column, direction)
)
self._unfinished_ids = deque()
@@ -66,8 +83,8 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
with self._lock:
- self._current_max += 1
- next_id = self._current_max
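+ # advance one step: +1 for forward streams, -1 for backfill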
+ self._current += self._direction
+ next_id = self._current
self._unfinished_ids.append(next_id)
@@ -88,8 +105,12 @@ class StreamIdGenerator(object):
# ... persist events ...
"""
with self._lock:
- next_ids = range(self._current_max + 1, self._current_max + n + 1)
- self._current_max += n
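+ # reserve n ids stepping in the generator's direction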
+ next_ids = range(
+ self._current + self._direction,
+ self._current + self._direction * (n + 1),
+ self._direction
+ )
+ self._current += n * self._direction
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@@ -105,15 +126,15 @@ class StreamIdGenerator(object):
return manager()
- def get_max_token(self):
+ def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - 1
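+ # the current token trails the oldest in-flight id by one step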
+ return self._unfinished_ids[0] - self._direction
- return self._current_max
+ return self._current
class ChainedIdGenerator(object):
@@ -125,7 +146,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
- self._current_max = _load_max_id(db_conn, table, column)
+ self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
@@ -137,7 +158,7 @@ class ChainedIdGenerator(object):
with self._lock:
self._current_max += 1
next_id = self._current_max
- chained_id = self.chained_generator.get_max_token()
+ chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id))
@@ -151,7 +172,7 @@ class ChainedIdGenerator(object):
return manager()
- def get_max_token(self):
+ def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
@@ -160,4 +181,4 @@ class ChainedIdGenerator(object):
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
- return (self._current_max, self.chained_generator.get_max_token())
+ return (self._current_max, self.chained_generator.get_current_token())