summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-04-01 14:12:33 +0100
committerErik Johnston <erik@matrix.org>2015-04-01 14:12:33 +0100
commit9236136f3a4f0d8119d4a6333f37378f8e259e4a (patch)
treec8e806fc815f9215501e409cf20471529cef4d15 /synapse/storage/_base.py
parentFix unicode database support (diff)
downloadsynapse-9236136f3a4f0d8119d4a6333f37378f8e259e4a.tar.xz
Make work in both Maria and SQLite. Fix tests
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 76ec3ee93f..047d100f46 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -29,6 +29,7 @@ import functools
 import simplejson as json
 import sys
 import time
+import threading
 
 
 logger = logging.getLogger(__name__)
@@ -118,19 +119,16 @@ def cached(max_entries=1000, num_args=1):
     return wrap
 
 
-def _convert_param_style(sql):
-    return sql.replace("?", "%s")
-
-
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name"]
+    __slots__ = ["txn", "name", "database_engine"]
 
-    def __init__(self, txn, name):
+    def __init__(self, txn, name, database_engine):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
+        object.__setattr__(self, "database_engine", database_engine)
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -142,7 +140,7 @@ class LoggingTransaction(object):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
-        sql = _convert_param_style(sql)
+        sql = self.database_engine.convert_param_style(sql)
 
         try:
             if args and args[0]:
@@ -227,9 +225,14 @@ class SQLBaseStore(object):
 
         self._get_event_cache = LruCache(hs.config.event_cache_size)
 
+        self.database_engine = hs.database_engine
+
         # Pretend the getEventCache is just another named cache
         caches_by_name["*getEvent*"] = self._get_event_cache
 
+        self._next_stream_id_lock = threading.Lock()
+        self._next_stream_id = int(hs.get_clock().time_msec()) * 1000
+
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
@@ -281,7 +284,10 @@ class SQLBaseStore(object):
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
                 transaction_logger.debug("[TXN START] {%s}", name)
                 try:
-                    return func(LoggingTransaction(txn, name), *args, **kwargs)
+                    return func(
+                        LoggingTransaction(txn, name, self.database_engine),
+                        *args, **kwargs
+                    )
                 except:
                     logger.exception("[TXN FAIL] {%s}", name)
                     raise
@@ -588,7 +594,7 @@ class SQLBaseStore(object):
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )
 
         txn.execute(select_sql, keyvalues.values())
@@ -836,6 +842,12 @@ class SQLBaseStore(object):
         result = txn.fetchone()
         return result[0] if result else None
 
+    def get_next_stream_id(self):
+        with self._next_stream_id_lock:
+            i = self._next_stream_id
+            self._next_stream_id += 1
+            return i
+
 
 class _RollbackButIsFineException(Exception):
     """ This exception is used to rollback a transaction without implying