summary refs log tree commit diff
path: root/synapse/storage/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util')
-rw-r--r--synapse/storage/util/id_generators.py84
1 files changed, 58 insertions, 26 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index efe3f68e6e..af425ba9a4 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,23 +20,21 @@ import threading
 
 class IdGenerator(object):
     def __init__(self, db_conn, table, column):
-        self.table = table
-        self.column = column
         self._lock = threading.Lock()
-        cur = db_conn.cursor()
-        self._next_id = self._load_next_id(cur)
-        cur.close()
-
-    def _load_next_id(self, txn):
-        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
-        val, = txn.fetchone()
-        return val + 1 if val else 1
+        self._next_id = _load_max_id(db_conn, table, column)
 
     def get_next(self):
         with self._lock:
-            i = self._next_id
             self._next_id += 1
-            return i
+            return self._next_id
+
+
+def _load_max_id(db_conn, table, column):
+    cur = db_conn.cursor()
+    cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    val, = cur.fetchone()
+    cur.close()
+    return val if val else 1
 
 
 class StreamIdGenerator(object):
@@ -52,23 +50,10 @@ class StreamIdGenerator(object):
             # ... persist event ...
     """
     def __init__(self, db_conn, table, column):
-        self.table = table
-        self.column = column
-
         self._lock = threading.Lock()
-
-        cur = db_conn.cursor()
-        self._current_max = self._load_current_max(cur)
-        cur.close()
-
+        self._current_max = _load_max_id(db_conn, table, column)
         self._unfinished_ids = deque()
 
-    def _load_current_max(self, txn):
-        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
-        rows = txn.fetchall()
-        val, = rows[0]
-        return int(val) if val else 1
-
     def get_next(self):
         """
         Usage:
@@ -124,3 +109,50 @@ class StreamIdGenerator(object):
                 return self._unfinished_ids[0] - 1
 
             return self._current_max
+
+
+class ChainedIdGenerator(object):
+    """Used to generate new stream ids where the stream must be kept in sync
+    with another stream. It generates pairs of IDs, the first element is an
+    integer ID for this stream, the second element is the ID for the stream
+    that this stream needs to be kept in sync with."""
+
+    def __init__(self, chained_generator, db_conn, table, column):
+        self.chained_generator = chained_generator
+        self._lock = threading.Lock()
+        self._current_max = _load_max_id(db_conn, table, column)
+        self._unfinished_ids = deque()
+
+    def get_next(self):
+        """
+        Usage:
+            with stream_id_gen.get_next() as (stream_id, chained_id):
+                # ... persist event ...
+        """
+        with self._lock:
+            self._current_max += 1
+            next_id = self._current_max
+            chained_id = self.chained_generator.get_max_token()
+
+            self._unfinished_ids.append((next_id, chained_id))
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield (next_id, chained_id)
+            finally:
+                with self._lock:
+                    self._unfinished_ids.remove((next_id, chained_id))
+
+        return manager()
+
+    def get_max_token(self):
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+        """
+        with self._lock:
+            if self._unfinished_ids:
+                stream_id, chained_id = self._unfinished_ids[0]
+                return (stream_id - 1, chained_id)
+
+            return (self._current_max, self.chained_generator.get_max_token())