diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index a02dfc7d58..f69f1cdad4 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,7 +21,7 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
- self._next_id = _load_max_id(db_conn, table, column)
+ self._next_id = _load_current_id(db_conn, table, column)
def get_next(self):
with self._lock:
@@ -29,12 +29,16 @@ class IdGenerator(object):
return self._next_id
-def _load_max_id(db_conn, table, column):
+def _load_current_id(db_conn, table, column, step=1):
cur = db_conn.cursor()
- cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ if step == 1:
+ cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ else:
+ cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
- return int(val) if val else 1
+ current_id = int(val) if val else step
+ return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator(object):
@@ -45,17 +49,32 @@ class StreamIdGenerator(object):
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
+ Args:
+ db_conn(connection): A database connection to use to fetch the
+ initial value of the generator from.
+ table(str): A database table to read the initial value of the id
+ generator from.
+ column(str): The column of the database table to read the initial
+ value from the id generator from.
+ extra_tables(list): List of pairs of database tables and columns to
+ use to source the initial value of the generator from. The value
+ with the largest magnitude is used.
+ step(int): which direction the stream ids grow in. +1 to grow
+ upwards, -1 to grow downwards.
+
Usage:
with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- def __init__(self, db_conn, table, column, extra_tables=[]):
+ def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+ assert step != 0
self._lock = threading.Lock()
- self._current_max = _load_max_id(db_conn, table, column)
+ self._step = step
+ self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
- self._current_max = max(
- self._current_max,
- _load_max_id(db_conn, table, column)
+ self._current = (max if step > 0 else min)(
+ self._current,
+ _load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque()
@@ -66,8 +85,8 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
with self._lock:
- self._current_max += 1
- next_id = self._current_max
+ self._current += self._step
+ next_id = self._current
self._unfinished_ids.append(next_id)
@@ -88,8 +107,12 @@ class StreamIdGenerator(object):
# ... persist events ...
"""
with self._lock:
- next_ids = range(self._current_max + 1, self._current_max + n + 1)
- self._current_max += n
+ next_ids = range(
+ self._current + self._step,
+ self._current + self._step * (n + 1),
+ self._step
+ )
+ self._current += n
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@@ -105,15 +128,15 @@ class StreamIdGenerator(object):
return manager()
- def get_max_token(self):
+ def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - 1
+ return self._unfinished_ids[0] - self._step
- return self._current_max
+ return self._current
class ChainedIdGenerator(object):
@@ -125,7 +148,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
- self._current_max = _load_max_id(db_conn, table, column)
+ self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
@@ -137,7 +160,7 @@ class ChainedIdGenerator(object):
with self._lock:
self._current_max += 1
next_id = self._current_max
- chained_id = self.chained_generator.get_max_token()
+ chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id))
@@ -151,7 +174,7 @@ class ChainedIdGenerator(object):
return manager()
- def get_max_token(self):
+ def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
@@ -160,4 +183,4 @@ class ChainedIdGenerator(object):
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
- return (self._current_max, self.chained_generator.get_max_token())
+ return (self._current_max, self.chained_generator.get_current_token())
|