summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py46
1 files changed, 3 insertions, 43 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 23289bbdd4..badf9a5f40 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
 
+from util.id_generators import IdGenerator, StreamIdGenerator
+
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
@@ -29,7 +31,6 @@ import functools
 import simplejson as json
 import sys
 import time
-import threading
 
 
 logger = logging.getLogger(__name__)
@@ -232,46 +233,6 @@ class PerformanceCounters(object):
         return top_n_counters
 
 
-class IdGenerator(object):
-    def __init__(self, table, column, store):
-        self.table = table
-        self.column = column
-        self.store = store
-        self._lock = threading.Lock()
-        self._next_id = None
-
-    @defer.inlineCallbacks
-    def get_next(self):
-        with self._lock:
-            if not self._next_id:
-                res = yield self.store._execute_and_decode(
-                    "IdGenerator_%s" % (self.table,),
-                    "SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,)
-                )
-
-                self._next_id = (res and res[0] and res[0]["mx"]) or 1
-
-            i = self._next_id
-            self._next_id += 1
-            defer.returnValue(i)
-
-    def get_next_txn(self, txn):
-        with self._lock:
-            if self._next_id:
-                i = self._next_id
-                self._next_id += 1
-                return i
-            else:
-                txn.execute(
-                    "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
-                )
-
-                val, = txn.fetchone()
-                self._next_id = val or 2
-
-                return 1
-
-
 class SQLBaseStore(object):
     _TXN_ID = 0
 
@@ -297,7 +258,7 @@ class SQLBaseStore(object):
         # Pretend the getEventCache is just another named cache
         caches_by_name["*getEvent*"] = self._get_event_cache
 
-        self._stream_id_gen = IdGenerator("events", "stream_ordering", self)
+        self._stream_id_gen = StreamIdGenerator()
         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
@@ -363,7 +324,6 @@ class SQLBaseStore(object):
                                 *args, **kwargs
                             )
                         except self.database_engine.module.DatabaseError as e:
-                            logger.warn("[TXN DEADLOCK] {%s} %r, %r", name, e.errno, e.sqlstate)
                             if self.database_engine.is_deadlock(e):
                                 logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                                 if i < N: