Diffstat (limited to 'synapse/storage/_base.py')
 synapse/storage/_base.py | 134
 1 file changed, 108 insertions, 26 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c328b5274c..ee5587c721 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -31,7 +31,9 @@ import functools
 import simplejson as json
 import sys
 import time
+import threading
 
+DEBUG_CACHES = False
 
 logger = logging.getLogger(__name__)
 
@@ -68,9 +70,20 @@ class Cache(object):
 
         self.name = name
         self.keylen = keylen
-
+        self.sequence = 0
+        self.thread = None
         caches_by_name[name] = self.cache
 
+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
     def get(self, *keyargs):
         if len(keyargs) != self.keylen:
             raise ValueError("Expected a key to have %d items", self.keylen)
@@ -82,6 +95,13 @@ class Cache(object):
         cache_counter.inc_misses(self.name)
         raise KeyError()
 
+    def update(self, sequence, *args):
+        self.check_thread()
+        if self.sequence == sequence:
+            # Only update the cache if the cache's sequence number matches the
+            # number that the cache had before the SELECT was started (SYN-369)
+            self.prefill(*args)
+
     def prefill(self, *args):  # because Python 2 can't do (*keyargs, value)
         keyargs = args[:-1]
         value = args[-1]
@@ -96,9 +116,12 @@ class Cache(object):
         self.cache[keyargs] = value
 
     def invalidate(self, *keyargs):
+        self.check_thread()
         if len(keyargs) != self.keylen:
             raise ValueError("Expected a key to have %d items", self.keylen)
-
+        # Increment the sequence number so that any SELECT statements that
+        # raced with the INSERT don't update the cache (SYN-369)
+        self.sequence += 1
         self.cache.pop(keyargs, None)
 
 
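Taken together, sequence, update() and invalidate() close the SYN-369 race: a SELECT that was already in flight when an invalidation happened must not write its stale result back into the cache. A minimal standalone sketch of the pattern (not the Synapse API):

    class SeqCache(object):
        def __init__(self):
            self.cache = {}
            self.sequence = 0

        def invalidate(self, key):
            self.sequence += 1             # any in-flight read is now stale
            self.cache.pop(key, None)

        def update(self, sequence, key, value):
            if sequence == self.sequence:  # no invalidation raced with us
                self.cache[key] = value

    c = SeqCache()
    snapshot = c.sequence          # taken before the "SELECT" starts
    c.invalidate("k1")             # a concurrent write invalidates the key
    c.update(snapshot, "k1", "stale value")
    assert "k1" not in c.cache     # the stale result was discarded
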
@@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False):
         @defer.inlineCallbacks
         def wrapped(self, *keyargs):
             try:
-                defer.returnValue(cache.get(*keyargs))
+                cached_result = cache.get(*keyargs)
+                if DEBUG_CACHES:
+                    actual_result = yield orig(self, *keyargs)
+                    if actual_result != cached_result:
+                        logger.error(
+                            "Stale cache entry %s%r: cached: %r, actual %r",
+                            orig.__name__, keyargs,
+                            cached_result, actual_result,
+                        )
+                        raise ValueError("Stale cache entry")
+                defer.returnValue(cached_result)
             except KeyError:
+                # Get the sequence number of the cache before reading from the
+                # database so that we can tell if the cache is invalidated
+                # while the SELECT is executing (SYN-369)
+                sequence = cache.sequence
+
                 ret = yield orig(self, *keyargs)
 
-                cache.prefill(*keyargs + (ret,))
+                cache.update(sequence, *keyargs + (ret,))
 
                 defer.returnValue(ret)
 
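For context, a hypothetical store method using the decorator; the method name and query are illustrative, not taken from this diff, and the query helper is assumed from elsewhere in this file. With DEBUG_CACHES = True, every cache hit also re-runs the query, and a mismatch raises the "Stale cache entry" error above:

    class ExampleStore(SQLBaseStore):
        @cached(num_args=1)
        def get_profile_displayname(self, user_id):
            # Any deferred-returning database call works here; wrapped()
            # yields on the original method when the cache misses.
            return self._simple_select_one_onecol(
                table="profiles",
                keyvalues={"user_id": user_id},
                retcol="displayname",
            )
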
@@ -147,12 +185,20 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name", "database_engine"]
+    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
 
-    def __init__(self, txn, name, database_engine):
+    def __init__(self, txn, name, database_engine, after_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
+        object.__setattr__(self, "after_callbacks", after_callbacks)
+
+    def call_after(self, callback, *args):
+        """Call the given callback on the main twisted thread after the
+        transaction has finished. Used to invalidate the caches on the
+        correct thread.
+        """
+        self.after_callbacks.append((callback, args))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
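
A hypothetical transaction function showing the intended use (the table and displayname_cache are invented names): the invalidation is queued while the transaction runs on the database thread, but only executes on the main thread once the transaction has finished, which keeps Cache.check_thread() happy:

    def _update_displayname_txn(self, txn, user_id, new_name):
        txn.execute(
            "UPDATE profiles SET displayname = ? WHERE user_id = ?",
            (new_name, user_id,)
        )
        # Queued now, run later by _runInteraction on the main thread.
        txn.call_after(displayname_cache.invalidate, user_id)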
@@ -160,22 +206,23 @@ class LoggingTransaction(object):
     def __setattr__(self, name, value):
         setattr(self.txn, name, value)
 
-    def execute(self, sql, *args, **kwargs):
+    def execute(self, sql, *args):
+        self._do_execute(self.txn.execute, sql, *args)
+
+    def executemany(self, sql, *args):
+        self._do_execute(self.txn.executemany, sql, *args)
+
+    def _do_execute(self, func, sql, *args):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
         sql = self.database_engine.convert_param_style(sql)
 
-        if args and args[0]:
-            args = list(args)
-            args[0] = [
-                self.database_engine.encode_parameter(a) for a in args[0]
-            ]
+        if args:
             try:
                 sql_logger.debug(
-                    "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
-                    self.name,
-                    *args[0]
+                    "[SQL values] {%s} %r",
+                    self.name, args[0]
                 )
             except:
                 # Don't let logging failures stop SQL from working
@@ -184,8 +231,8 @@ class LoggingTransaction(object):
         start = time.time() * 1000
 
         try:
-            return self.txn.execute(
-                sql, *args, **kwargs
+            return func(
+                sql, *args
             )
         except Exception as e:
             logger.debug("[SQL FAIL] {%s} %s", self.name, e)
@@ -298,6 +345,8 @@ class SQLBaseStore(object):
 
         start_time = time.time() * 1000
 
+        after_callbacks = []
+
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
                 if self.database_engine.is_connection_closed(conn):
@@ -322,10 +371,10 @@ class SQLBaseStore(object):
                     while True:
                         try:
                             txn = conn.cursor()
-                            return func(
-                                LoggingTransaction(txn, name, self.database_engine),
-                                *args, **kwargs
+                            txn = LoggingTransaction(
+                                txn, name, self.database_engine, after_callbacks
                             )
+                            return func(txn, *args, **kwargs)
                         except self.database_engine.module.OperationalError as e:
                             # This can happen if the database disappears mid
                             # transaction.
@@ -374,6 +423,8 @@ class SQLBaseStore(object):
             result = yield self._db_pool.runWithConnection(
                 inner_func, *args, **kwargs
             )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
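
Putting the halves together: inner_func runs on a database pool thread and only collects callbacks, while _runInteraction drains them on the main thread after runWithConnection has returned. Schematically (an illustration, not real API calls):

    # main (reactor) thread                database pool thread
    # ---------------------                --------------------
    # runInteraction(name, func) ------->  inner_func(conn)
    #                                        txn.call_after(cb, key)
    #                                        func(txn, ...) returns
    # result <--------------------------  (transaction finished)
    # for after_callback, after_args in after_callbacks:
    #     after_callback(*after_args)     # e.g. cache.invalidate(key)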
@@ -438,18 +489,49 @@ class SQLBaseStore(object):
 
     @log_function
     def _simple_insert_txn(self, txn, table, values):
+        keys, vals = zip(*values.items())
+
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
-            ", ".join(k for k in values),
-            ", ".join("?" for k in values)
+            ", ".join(k for k in keys),
+            ", ".join("?" for _ in keys)
         )
 
-        logger.debug(
-            "[SQL] %s Args=%s",
-            sql, values.values(),
+        txn.execute(sql, vals)
+
+    def _simple_insert_many_txn(self, txn, table, values):
+        if not values:
+            return
+
+        # This is a *slight* abomination to get a list of tuples of key names
+        # and a list of tuples of values.
+        #
+        # i.e. [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
+        #         => [("a", "b",), ("a", "b",)] and [(1, 2,), (3, 4,)]
+        #
+        # The sort is to ensure that we don't rely on dictionary iteration
+        # order.
+        keys, vals = zip(*[
+            zip(
+                *(sorted(i.items(), key=lambda kv: kv[0]))
+            )
+            for i in values
+            if i
+        ])
+
+        for k in keys:
+            if k != keys[0]:
+                raise RuntimeError(
+                    "All items must have the same keys"
+                )
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys[0]),
+            ", ".join("?" for _ in keys[0])
         )
 
-        txn.execute(sql, values.values())
+        txn.executemany(sql, vals)
 
     def _simple_upsert(self, table, keyvalues, values,
                        insertion_values={}, desc="_simple_upsert", lock=True):
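
The nested zip in _simple_insert_many_txn is easiest to see with concrete data; a standalone sketch of the transposition (values invented):

    rows = [{"b": 2, "a": 1}, {"a": 3, "b": 4}]

    keys, vals = zip(*[
        zip(*sorted(row.items(), key=lambda kv: kv[0]))
        for row in rows
        if row
    ])

    assert keys == (("a", "b"), ("a", "b"))  # one key tuple per row
    assert vals == ((1, 2), (3, 4))          # ready for executemany()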