diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index efe3f68e6e..af425ba9a4 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,23 +20,21 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
-        self.table = table
-        self.column = column
self._lock = threading.Lock()
-        cur = db_conn.cursor()
-        self._next_id = self._load_next_id(cur)
-        cur.close()
-
-    def _load_next_id(self, txn):
-        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
-        val, = txn.fetchone()
-        return val + 1 if val else 1
+        self._next_id = _load_max_id(db_conn, table, column)  # current MAX(column); incremented before each return
def get_next(self):
with self._lock:
-        i = self._next_id
self._next_id += 1
-        return i
+        return self._next_id  # NOTE(review): increment-then-return — on an empty table the first id is now 2 where the old code returned 1; confirm intended
+
+
+def _load_max_id(db_conn, table, column):
+    cur = db_conn.cursor()
+    cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    val, = cur.fetchone()
+    cur.close()
+    return int(val) if val else 1  # int() cast kept from _load_current_max: drivers may return non-int for MAX()
class StreamIdGenerator(object):
@@ -52,23 +50,10 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
def __init__(self, db_conn, table, column):
-        self.table = table
-        self.column = column
-
self._lock = threading.Lock()
-
-        cur = db_conn.cursor()
-        self._current_max = self._load_current_max(cur)
-        cur.close()
-
+        self._current_max = _load_max_id(db_conn, table, column)  # highest id currently persisted in the table
self._unfinished_ids = deque()
- def _load_current_max(self, txn):
- txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
- rows = txn.fetchall()
- val, = rows[0]
- return int(val) if val else 1
-
def get_next(self):
"""
Usage:
@@ -124,3 +109,50 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1
return self._current_max
+
+
+class ChainedIdGenerator(object):
+    """Used to generate new stream ids where the stream must be kept in sync
+    with another stream. It generates pairs of IDs, the first element is an
+    integer ID for this stream, the second element is the ID for the stream
+    that this stream needs to be kept in sync with."""
+
+    def __init__(self, chained_generator, db_conn, table, column):
+        self.chained_generator = chained_generator  # generator for the stream this one is kept in sync with
+        self._lock = threading.Lock()  # guards _current_max and _unfinished_ids
+        self._current_max = _load_max_id(db_conn, table, column)  # highest id ever allocated for this stream
+        self._unfinished_ids = deque()  # (stream_id, chained_id) pairs allocated but not yet persisted, oldest first
+
+    def get_next(self):
+        """
+        Usage:
+        with stream_id_gen.get_next() as (stream_id, chained_id):
+            # ... persist event ...
+        """
+        with self._lock:
+            self._current_max += 1  # allocate a fresh id under the lock
+            next_id = self._current_max
+            chained_id = self.chained_generator.get_max_token()  # snapshot of the synced stream at allocation time
+
+            self._unfinished_ids.append((next_id, chained_id))
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield (next_id, chained_id)
+            finally:
+                with self._lock:
+                    self._unfinished_ids.remove((next_id, chained_id))  # done (persisted or failed) — no longer in flight
+
+        return manager()
+
+    def get_max_token(self):
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+        """
+        with self._lock:
+            if self._unfinished_ids:
+                stream_id, chained_id = self._unfinished_ids[0]  # oldest still-unfinished allocation
+                return (stream_id - 1, chained_id)  # everything strictly below the oldest in-flight id is persisted
+
+            return (self._current_max, self.chained_generator.get_max_token())