diff options
Diffstat (limited to 'synapse/storage/util')
-rw-r--r-- | synapse/storage/util/id_generators.py | 84 |
1 files changed, 58 insertions, 26 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index efe3f68e6e..af425ba9a4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -20,23 +20,21 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): - self.table = table - self.column = column self._lock = threading.Lock() - cur = db_conn.cursor() - self._next_id = self._load_next_id(cur) - cur.close() - - def _load_next_id(self, txn): - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,)) - val, = txn.fetchone() - return val + 1 if val else 1 + self._next_id = _load_max_id(db_conn, table, column) def get_next(self): with self._lock: - i = self._next_id self._next_id += 1 - return i + return self._next_id + + +def _load_max_id(db_conn, table, column): + cur = db_conn.cursor() + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + val, = cur.fetchone() + cur.close() + return val if val else 1 class StreamIdGenerator(object): @@ -52,23 +50,10 @@ class StreamIdGenerator(object): # ... persist event ... """ def __init__(self, db_conn, table, column): - self.table = table - self.column = column - self._lock = threading.Lock() - - cur = db_conn.cursor() - self._current_max = self._load_current_max(cur) - cur.close() - + self._current_max = _load_max_id(db_conn, table, column) self._unfinished_ids = deque() - def _load_current_max(self, txn): - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] - return int(val) if val else 1 - def get_next(self): """ Usage: @@ -124,3 +109,50 @@ class StreamIdGenerator(object): return self._unfinished_ids[0] - 1 return self._current_max + + +class ChainedIdGenerator(object): + """Used to generate new stream ids where the stream must be kept in sync + with another stream. It generates pairs of IDs, the first element is an + integer ID for this stream, the second element is the ID for the stream + that this stream needs to be kept in sync with.""" + + def __init__(self, chained_generator, db_conn, table, column): + self.chained_generator = chained_generator + self._lock = threading.Lock() + self._current_max = _load_max_id(db_conn, table, column) + self._unfinished_ids = deque() + + def get_next(self): + """ + Usage: + with stream_id_gen.get_next() as (stream_id, chained_id): + # ... persist event ... + """ + with self._lock: + self._current_max += 1 + next_id = self._current_max + chained_id = self.chained_generator.get_max_token() + + self._unfinished_ids.append((next_id, chained_id)) + + @contextlib.contextmanager + def manager(): + try: + yield (next_id, chained_id) + finally: + with self._lock: + self._unfinished_ids.remove((next_id, chained_id)) + + return manager() + + def get_max_token(self): + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + with self._lock: + if self._unfinished_ids: + stream_id, chained_id = self._unfinished_ids[0] + return (stream_id - 1, chained_id) + + return (self._current_max, self.chained_generator.get_max_token()) |