diff options
-rw-r--r-- | synapse/storage/__init__.py | 20 | ||||
-rw-r--r-- | synapse/storage/account_data.py | 4 | ||||
-rw-r--r-- | synapse/storage/events.py | 19 | ||||
-rw-r--r-- | synapse/storage/presence.py | 6 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 2 | ||||
-rw-r--r-- | synapse/storage/pusher.py | 2 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 6 | ||||
-rw-r--r-- | synapse/storage/state.py | 2 | ||||
-rw-r--r-- | synapse/storage/stream.py | 2 | ||||
-rw-r--r-- | synapse/storage/tags.py | 6 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 63 |
11 files changed, 71 insertions, 61 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index aaad38039e..57863bba4d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore, self.hs = hs self.database_engine = hs.database_engine - cur = db_conn.cursor() - try: - cur.execute("SELECT MIN(stream_ordering) FROM events",) - rows = cur.fetchall() - self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 - self.min_stream_token = min(self.min_stream_token, -1) - finally: - cur.close() - self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, @@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore, self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering" ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", step=-1 + ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) @@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, extra_tables=[("deleted_pushers", "stream_id")], ) - events_max = self._stream_id_gen.get_max_token() + events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token() + account_max = self._account_data_id_gen.get_current_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) @@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._presence_id_gen.get_max_token(), + max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, @@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_max_token()[0], + max_value=self._push_rules_stream_id_gen.get_current_token()[0], ) self.push_rules_stream_cache = StreamChangeCache( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index faddefe219..7a7fbf1e52 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore): "add_room_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore): "add_user_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7468e6e00c..c4dc3b3d51 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -24,7 +24,6 @@ from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json -from contextlib import contextmanager from collections import namedtuple import logging @@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore): return if backfilled: - start = self.min_stream_token - 1 - self.min_stream_token -= len(events_and_contexts) + 1 - stream_orderings = range(start, self.min_stream_token, -1) - - @contextmanager - def stream_ordering_manager(): - yield stream_orderings - stream_ordering_manager = stream_ordering_manager() + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) else: stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) @@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token() + max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore): def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" - - # TODO: Fix race with the persit_event txn by using one of the - # stream id managers - return -self.min_stream_token + return -self._backfill_id_gen.get_current_token() def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4cec31e316..59b4ef5ce6 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore): self._update_presence_txn, stream_orderings, presence_states, ) - defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + defer.returnValue(( + stream_orderings[-1], self._presence_id_gen.get_current_token() + )) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): @@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore): defer.returnValue([UserPresenceState(**row) for row in rows]) def get_current_presence_token(self): - return self._presence_id_gen.get_max_token() + return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9dbad2fd5f..d2bf7f2aec 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_max_token() + return self._push_rules_stream_id_gen.get_current_token() def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 87b2ac5773..d1669c778a 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) def get_pushers_stream_token(self): - return self._pushers_id_gen.get_max_token() + return self._pushers_id_gen.get_current_token() def get_all_updated_pushers(self, last_id, current_id, limit): def get_all_updated_pushers_txn(txn): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 6b9d848eaa..4befebc8e2 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() ) @cached(num_args=2) @@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token() + return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f264..8644830657 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -458,4 +458,4 @@ class StateStore(SQLBaseStore): ) def get_state_stream_token(self): - return self._state_groups_id_gen.get_max_token() + return self._state_groups_id_gen.get_current_token() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index cf84938be5..76bcd9cd00 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token() + token = yield self._stream_id_gen.get_current_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index a0e6b42b30..9da23f34cb 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token() + return self._account_data_id_gen.get_current_token() @cached() def get_tags_for_user(self, user_id): @@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a02dfc7d58..f69f1cdad4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,7 +21,7 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = _load_max_id(db_conn, table, column) + self._next_id = _load_current_id(db_conn, table, column) def get_next(self): with self._lock: @@ -29,12 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_max_id(db_conn, table, column): +def _load_current_id(db_conn, table, column, step=1): cur = db_conn.cursor() - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + if step == 1: + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + else: + cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return int(val) if val else 1 + current_id = int(val) if val else step + return (max if step > 0 else min)(current_id, step) class StreamIdGenerator(object): @@ -45,17 +49,32 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. + Args: + db_conn(connection): A database connection to use to fetch the + initial value of the generator from. + table(str): A database table to read the initial value of the id + generator from. + column(str): The column of the database table to read the initial + value from the id generator from. + extra_tables(list): List of pairs of database tables and columns to + use to source the initial value of the generator from. The value + with the largest magnitude is used. + step(int): which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. + Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[]): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + assert step != 0 self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._step = step + self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self._current_max = max( - self._current_max, - _load_max_id(db_conn, table, column) + self._current = (max if step > 0 else min)( + self._current, + _load_current_id(db_conn, table, column, step) ) self._unfinished_ids = deque() @@ -66,8 +85,8 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current_max += 1 - next_id = self._current_max + self._current += self._step + next_id = self._current self._unfinished_ids.append(next_id) @@ -88,8 +107,12 @@ class StreamIdGenerator(object): # ... persist events ... """ with self._lock: - next_ids = range(self._current_max + 1, self._current_max + n + 1) - self._current_max += n + next_ids = range( + self._current + self._step, + self._current + self._step * (n + 1), + self._step + ) + self._current += n for next_id in next_ids: self._unfinished_ids.append(next_id) @@ -105,15 +128,15 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - 1 + return self._unfinished_ids[0] - self._step - return self._current_max + return self._current class ChainedIdGenerator(object): @@ -125,7 +148,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() def get_next(self): @@ -137,7 +160,7 @@ class ChainedIdGenerator(object): with self._lock: self._current_max += 1 next_id = self._current_max - chained_id = self.chained_generator.get_max_token() + chained_id = self.chained_generator.get_current_token() self._unfinished_ids.append((next_id, chained_id)) @@ -151,7 +174,7 @@ class ChainedIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -160,4 +183,4 @@ class ChainedIdGenerator(object): stream_id, chained_id = self._unfinished_ids[0] return (stream_id - 1, chained_id) - return (self._current_max, self.chained_generator.get_max_token()) + return (self._current_max, self.chained_generator.get_current_token()) |