diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 76ec3ee93f..047d100f46 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -29,6 +29,7 @@ import functools
import simplejson as json
import sys
import time
+import threading
logger = logging.getLogger(__name__)
@@ -118,19 +119,16 @@ def cached(max_entries=1000, num_args=1):
return wrap
-def _convert_param_style(sql):
- return sql.replace("?", "%s")
-
-
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name"]
+ __slots__ = ["txn", "name", "database_engine"]
- def __init__(self, txn, name):
+ def __init__(self, txn, name, database_engine):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
+ object.__setattr__(self, "database_engine", database_engine)
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -142,7 +140,7 @@ class LoggingTransaction(object):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
- sql = _convert_param_style(sql)
+ sql = self.database_engine.convert_param_style(sql)
try:
if args and args[0]:
@@ -227,9 +225,14 @@ class SQLBaseStore(object):
self._get_event_cache = LruCache(hs.config.event_cache_size)
+ self.database_engine = hs.database_engine
+
# Pretend the getEventCache is just another named cache
caches_by_name["*getEvent*"] = self._get_event_cache
+ self._next_stream_id_lock = threading.Lock()
+ self._next_stream_id = int(hs.get_clock().time_msec()) * 1000
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -281,7 +284,10 @@ class SQLBaseStore(object):
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
- return func(LoggingTransaction(txn, name), *args, **kwargs)
+ return func(
+ LoggingTransaction(txn, name, self.database_engine),
+ *args, **kwargs
+ )
except:
logger.exception("[TXN FAIL] {%s}", name)
raise
@@ -588,7 +594,7 @@ class SQLBaseStore(object):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
)
txn.execute(select_sql, keyvalues.values())
@@ -836,6 +842,12 @@ class SQLBaseStore(object):
result = txn.fetchone()
return result[0] if result else None
+ def get_next_stream_id(self):
+ with self._next_stream_id_lock:
+ i = self._next_stream_id
+ self._next_stream_id += 1
+ return i
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
|