summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/util/id_generators.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d461d5e96..54aeff2b43 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -86,10 +86,10 @@ class StreamIdGenerator(object):
             with stream_id_gen.get_next_txn(txn) as stream_id:
                 # ... persist event ...
         """
-        with self._lock:
-            if not self._current_max:
-                self._compute_current_max(txn)
+        if not self._current_max:
+            self._compute_current_max(txn)
 
+        with self._lock:
             self._current_max += 1
             next_id = self._current_max
 
@@ -110,22 +110,24 @@ class StreamIdGenerator(object):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
+        if not self._current_max:
+            yield store.runInteraction(
+                "_compute_current_max",
+                self._get_or_compute_current_max,
+            )
+
         with self._lock:
             if self._unfinished_ids:
                 defer.returnValue(self._unfinished_ids[0] - 1)
 
-            if not self._current_max:
-                yield store.runInteraction(
-                    "_compute_current_max",
-                    self._compute_current_max,
-                )
-
             defer.returnValue(self._current_max)
 
-    def _compute_current_max(self, txn):
-        txn.execute("SELECT MAX(stream_ordering) FROM events")
-        val, = txn.fetchone()
+    def _get_or_compute_current_max(self, txn):
+        with self._lock:
+            txn.execute("SELECT MAX(stream_ordering) FROM events")
+            rows = txn.fetchall()
+            val, = rows[0]
 
-        self._current_max = int(val) if val else 1
+            self._current_max = int(val) if val else 1
 
-        return self._current_max
+            return self._current_max