diff options
Diffstat (limited to 'synapse/storage/util')
-rw-r--r-- | synapse/storage/util/id_generators.py | 63 |
1 files changed, 43 insertions, 20 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a02dfc7d58..f69f1cdad4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,7 +21,7 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = _load_max_id(db_conn, table, column) + self._next_id = _load_current_id(db_conn, table, column) def get_next(self): with self._lock: @@ -29,12 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_max_id(db_conn, table, column): +def _load_current_id(db_conn, table, column, step=1): cur = db_conn.cursor() - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + if step == 1: + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + else: + cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return int(val) if val else 1 + current_id = int(val) if val else step + return (max if step > 0 else min)(current_id, step) class StreamIdGenerator(object): @@ -45,17 +49,32 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. + Args: + db_conn(connection): A database connection to use to fetch the + initial value of the generator from. + table(str): A database table to read the initial value of the id + generator from. + column(str): The column of the database table to read the initial + value from the id generator from. + extra_tables(list): List of pairs of database tables and columns to + use to source the initial value of the generator from. The value + with the largest magnitude is used. + step(int): which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. + Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[]): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + assert step != 0 self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._step = step + self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self._current_max = max( - self._current_max, - _load_max_id(db_conn, table, column) + self._current = (max if step > 0 else min)( + self._current, + _load_current_id(db_conn, table, column, step) ) self._unfinished_ids = deque() @@ -66,8 +85,8 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current_max += 1 - next_id = self._current_max + self._current += self._step + next_id = self._current self._unfinished_ids.append(next_id) @@ -88,8 +107,12 @@ class StreamIdGenerator(object): # ... persist events ... """ with self._lock: - next_ids = range(self._current_max + 1, self._current_max + n + 1) - self._current_max += n + next_ids = range( + self._current + self._step, + self._current + self._step * (n + 1), + self._step + ) + self._current += n for next_id in next_ids: self._unfinished_ids.append(next_id) @@ -105,15 +128,15 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - 1 + return self._unfinished_ids[0] - self._step - return self._current_max + return self._current class ChainedIdGenerator(object): @@ -125,7 +148,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() def get_next(self): @@ -137,7 +160,7 @@ class ChainedIdGenerator(object): with self._lock: self._current_max += 1 next_id = self._current_max - chained_id = self.chained_generator.get_max_token() + chained_id = self.chained_generator.get_current_token() self._unfinished_ids.append((next_id, chained_id)) @@ -151,7 +174,7 @@ class ChainedIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -160,4 +183,4 @@ class ChainedIdGenerator(object): stream_id, chained_id = self._unfinished_ids[0] return (stream_id - 1, chained_id) - return (self._current_max, self.chained_generator.get_max_token()) + return (self._current_max, self.chained_generator.get_current_token()) |