summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py54
1 files changed, 48 insertions, 6 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index de4f661973..9f63f07080 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -163,8 +163,8 @@ class LoggingTransaction(object):
             return self.txn.execute(
                 sql, *args, **kwargs
             )
-        except:
-                logger.exception("[SQL FAIL] {%s}", self.name)
+        except Exception as e:
+                logger.debug("[SQL FAIL] {%s} %s", self.name, e)
                 raise
         finally:
             msecs = (time.time() * 1000) - start
@@ -209,6 +209,46 @@ class PerformanceCounters(object):
         return top_n_counters
 
 
+class IdGenerator(object):
+    def __init__(self, table, column, store):
+        self.table = table
+        self.column = column
+        self.store = store
+        self._lock = threading.Lock()
+        self._next_id = None
+
+    @defer.inlineCallbacks
+    def get_next(self):
+        with self._lock:
+            if not self._next_id:
+                res = yield self.store._execute_and_decode(
+                    "IdGenerator_%s" % (self.table,),
+                    "SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,)
+                )
+
+                self._next_id = (res and res[0] and res[0]["mx"]) or 1
+
+            i = self._next_id
+            self._next_id += 1
+            defer.returnValue(i)
+
+    def get_next_txn(self, txn):
+        with self._lock:
+            if self._next_id:
+                i = self._next_id
+                self._next_id += 1
+                return i
+            else:
+                txn.execute(
+                    "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
+                )
+
+                val, = txn.fetchone()
+                self._next_id = val or 2
+
+                return 1
+
+
 class SQLBaseStore(object):
     _TXN_ID = 0
 
@@ -234,8 +274,10 @@ class SQLBaseStore(object):
         # Pretend the getEventCache is just another named cache
         caches_by_name["*getEvent*"] = self._get_event_cache
 
-        self._next_stream_id_lock = threading.Lock()
-        self._next_stream_id = int(hs.get_clock().time_msec()) * 1000
+        self._stream_id_gen = IdGenerator("events", "stream_ordering", self)
+        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -292,8 +334,8 @@ class SQLBaseStore(object):
                         LoggingTransaction(txn, name, self.database_engine),
                         *args, **kwargs
                     )
-                except:
-                    logger.exception("[TXN FAIL] {%s}", name)
+                except Exception as e:
+                    logger.debug("[TXN FAIL] {%s}", name, e)
                     raise
                 finally:
                     end = time.time() * 1000