diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c328b5274c..7f5477dee5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -31,6 +31,7 @@ import functools
import simplejson as json
import sys
import time
+import threading
logger = logging.getLogger(__name__)
@@ -68,9 +69,20 @@ class Cache(object):
self.name = name
self.keylen = keylen
-
+ self.sequence = 0
+ self.thread = None
caches_by_name[name] = self.cache
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
@@ -82,6 +94,11 @@ class Cache(object):
cache_counter.inc_misses(self.name)
raise KeyError()
+ def update(self, sequence, *args):
+ self.check_thread()
+ if self.sequence == sequence:
+ self.prefill(*args)
+
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
@@ -96,9 +113,10 @@ class Cache(object):
self.cache[keyargs] = value
def invalidate(self, *keyargs):
+ self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
-
+ self.sequence += 1
self.cache.pop(keyargs, None)
@@ -130,9 +148,11 @@ def cached(max_entries=1000, num_args=1, lru=False):
try:
defer.returnValue(cache.get(*keyargs))
except KeyError:
+ sequence = cache.sequence
+
ret = yield orig(self, *keyargs)
- cache.prefill(*keyargs + (ret,))
+ cache.update(sequence, *keyargs + (ret,))
defer.returnValue(ret)
|