diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e01c61d08d..a2da3dd1b1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -31,7 +31,9 @@ import functools
import simplejson as json
import sys
import time
+import threading
+DEBUG_CACHES = False
logger = logging.getLogger(__name__)
@@ -68,9 +70,20 @@ class Cache(object):
self.name = name
self.keylen = keylen
-
+ self.sequence = 0
+ self.thread = None
caches_by_name[name] = self.cache
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
@@ -82,6 +95,13 @@ class Cache(object):
cache_counter.inc_misses(self.name)
raise KeyError()
+ def update(self, sequence, *args):
+ self.check_thread()
+ if self.sequence == sequence:
+ # Only update the cache if the caches sequence number matches the
+ # number that the cache had before the SELECT was started (SYN-369)
+ self.prefill(*args)
+
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
@@ -96,9 +116,12 @@ class Cache(object):
self.cache[keyargs] = value
def invalidate(self, *keyargs):
+ self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
-
+ # Increment the sequence number so that any SELECT statements that
+ # raced with the INSERT don't update the cache (SYN-369)
+ self.sequence += 1
self.cache.pop(keyargs, None)
@@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False):
@defer.inlineCallbacks
def wrapped(self, *keyargs):
try:
- defer.returnValue(cache.get(*keyargs))
+ cached_result = cache.get(*keyargs)
+ if DEBUG_CACHES:
+ actual_result = yield orig(self, *keyargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ orig.__name__, keyargs,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
except KeyError:
+ # Get the sequence number of the cache before reading from the
+ # database so that we can tell if the cache is invalidated
+ # while the SELECT is executing (SYN-369)
+ sequence = cache.sequence
+
ret = yield orig(self, *keyargs)
- cache.prefill(*keyargs + (ret,))
+ cache.update(sequence, *keyargs + (ret,))
defer.returnValue(ret)
@@ -147,12 +185,20 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name", "database_engine"]
+ __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
- def __init__(self, txn, name, database_engine):
+ def __init__(self, txn, name, database_engine, after_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
+ object.__setattr__(self, "after_callbacks", after_callbacks)
+
+ def call_after(self, callback, *args):
+ """Call the given callback on the main twisted thread after the
+ transaction has finished. Used to invalidate the caches on the
+ correct thread.
+ """
+ self.after_callbacks.append((callback, args))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -294,6 +340,8 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
+ after_callbacks = []
+
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
if self.database_engine.is_connection_closed(conn):
@@ -318,10 +366,10 @@ class SQLBaseStore(object):
while True:
try:
txn = conn.cursor()
- return func(
- LoggingTransaction(txn, name, self.database_engine),
- *args, **kwargs
+ txn = LoggingTransaction(
+ txn, name, self.database_engine, after_callbacks
)
+ return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
@@ -370,6 +418,8 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
+ for after_callback, after_args in after_callbacks:
+ after_callback(*after_args)
defer.returnValue(result)
def cursor_to_dict(self, cursor):
|