diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 23289bbdd4..badf9a5f40 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
+from util.id_generators import IdGenerator, StreamIdGenerator
+
from twisted.internet import defer
from collections import namedtuple, OrderedDict
@@ -29,7 +31,6 @@ import functools
import simplejson as json
import sys
import time
-import threading
logger = logging.getLogger(__name__)
@@ -232,46 +233,6 @@ class PerformanceCounters(object):
return top_n_counters
-class IdGenerator(object):
- def __init__(self, table, column, store):
- self.table = table
- self.column = column
- self.store = store
- self._lock = threading.Lock()
- self._next_id = None
-
- @defer.inlineCallbacks
- def get_next(self):
- with self._lock:
- if not self._next_id:
- res = yield self.store._execute_and_decode(
- "IdGenerator_%s" % (self.table,),
- "SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,)
- )
-
- self._next_id = (res and res[0] and res[0]["mx"]) or 1
-
- i = self._next_id
- self._next_id += 1
- defer.returnValue(i)
-
- def get_next_txn(self, txn):
- with self._lock:
- if self._next_id:
- i = self._next_id
- self._next_id += 1
- return i
- else:
- txn.execute(
- "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
- )
-
- val, = txn.fetchone()
- self._next_id = val or 2
-
- return 1
-
-
class SQLBaseStore(object):
_TXN_ID = 0
@@ -297,7 +258,7 @@ class SQLBaseStore(object):
# Pretend the getEventCache is just another named cache
caches_by_name["*getEvent*"] = self._get_event_cache
- self._stream_id_gen = IdGenerator("events", "stream_ordering", self)
+ self._stream_id_gen = StreamIdGenerator()
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
@@ -363,7 +324,6 @@ class SQLBaseStore(object):
*args, **kwargs
)
except self.database_engine.module.DatabaseError as e:
- logger.warn("[TXN DEADLOCK] {%s} %r, %r", name, e.errno, e.sqlstate)
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
|