summary refs log tree commit diff
path: root/synapse/storage/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util')
-rw-r--r--synapse/storage/util/id_generators.py61
1 files changed, 41 insertions, 20 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index a02dfc7d58..03f2aa6a5c 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,7 +21,7 @@ import threading
 class IdGenerator(object):
     def __init__(self, db_conn, table, column):
         self._lock = threading.Lock()
-        self._next_id = _load_max_id(db_conn, table, column)
+        self._next_id = _load_current_id(db_conn, table, column)
 
     def get_next(self):
         with self._lock:
@@ -29,12 +29,16 @@ class IdGenerator(object):
             return self._next_id
 
 
-def _load_max_id(db_conn, table, column):
+def _load_current_id(db_conn, table, column, direction=1):
     cur = db_conn.cursor()
-    cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    if direction == 1:
+        cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    else:
+        cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
     val, = cur.fetchone()
     cur.close()
-    return int(val) if val else 1
+    current_id = int(val) if val else direction
+    return (max if direction == 1 else min)(current_id, direction)
 
 
 class StreamIdGenerator(object):
@@ -45,17 +49,30 @@ class StreamIdGenerator(object):
     all ids less than or equal to it have completed. This handles the fact that
     persistence of events can complete out of order.
 
+    :param connection db_conn:  A database connection to use to fetch the
+        initial value of the generator from.
+    :param str table: A database table to read the initial value of the id
+        generator from.
+    :param str column: The column of the database table to read the initial
+        value from the id generator from.
+    :param list extra_tables: List of pairs of database tables and columns to
+        use to source the initial value of the generator from. The value with
+        the largest magnitude is used.
+    :param int direction: which direction the stream ids grow in. +1 to grow
+        upwards, -1 to grow downwards.
+
     Usage:
         with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
-    def __init__(self, db_conn, table, column, extra_tables=[]):
+    def __init__(self, db_conn, table, column, extra_tables=[], direction=1):
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._direction = direction
+        self._current = _load_current_id(db_conn, table, column, direction)
         for table, column in extra_tables:
-            self._current_max = max(
-                self._current_max,
-                _load_max_id(db_conn, table, column)
+            self._current = (max if direction > 0 else min)(
+                self._current,
+                _load_current_id(db_conn, table, column, direction)
             )
         self._unfinished_ids = deque()
 
@@ -66,8 +83,8 @@ class StreamIdGenerator(object):
                 # ... persist event ...
         """
         with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
+            self._current += self._direction
+            next_id = self._current
 
             self._unfinished_ids.append(next_id)
 
@@ -88,8 +105,12 @@ class StreamIdGenerator(object):
                 # ... persist events ...
         """
         with self._lock:
-            next_ids = range(self._current_max + 1, self._current_max + n + 1)
-            self._current_max += n
+            next_ids = range(
+                self._current + self._direction,
+                self._current + self._direction * (n + 1),
+                self._direction
+            )
+            self._current += n
 
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
@@ -105,15 +126,15 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - 1
+                return self._unfinished_ids[0] - self._direction
 
-            return self._current_max
+            return self._current
 
 
 class ChainedIdGenerator(object):
@@ -125,7 +146,7 @@ class ChainedIdGenerator(object):
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._current_max = _load_current_id(db_conn, table, column)
         self._unfinished_ids = deque()
 
     def get_next(self):
@@ -137,7 +158,7 @@ class ChainedIdGenerator(object):
         with self._lock:
             self._current_max += 1
             next_id = self._current_max
-            chained_id = self.chained_generator.get_max_token()
+            chained_id = self.chained_generator.get_current_token()
 
             self._unfinished_ids.append((next_id, chained_id))
 
@@ -151,7 +172,7 @@ class ChainedIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
@@ -160,4 +181,4 @@ class ChainedIdGenerator(object):
                 stream_id, chained_id = self._unfinished_ids[0]
                 return (stream_id - 1, chained_id)
 
-            return (self._current_max, self.chained_generator.get_max_token())
+            return (self._current_max, self.chained_generator.get_current_token())