summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/_base.py26
1 files changed, 23 insertions, 3 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c328b5274c..7f5477dee5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -31,6 +31,7 @@ import functools
 import simplejson as json
 import sys
 import time
+import threading
 
 
 logger = logging.getLogger(__name__)
@@ -68,9 +69,20 @@ class Cache(object):
 
         self.name = name
         self.keylen = keylen
-
+        self.sequence = 0
+        self.thread = None
         caches_by_name[name] = self.cache
 
+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
     def get(self, *keyargs):
         if len(keyargs) != self.keylen:
             raise ValueError("Expected a key to have %d items", self.keylen)
@@ -82,6 +94,11 @@ class Cache(object):
         cache_counter.inc_misses(self.name)
         raise KeyError()
 
+    def update(self, sequence, *args):
+        self.check_thread()
+        if self.sequence == sequence:
+            self.prefill(*args)
+
     def prefill(self, *args):  # because I can't  *keyargs, value
         keyargs = args[:-1]
         value = args[-1]
@@ -96,9 +113,10 @@ class Cache(object):
         self.cache[keyargs] = value
 
     def invalidate(self, *keyargs):
+        self.check_thread()
         if len(keyargs) != self.keylen:
             raise ValueError("Expected a key to have %d items", self.keylen)
-
+        self.sequence += 1
         self.cache.pop(keyargs, None)
 
 
@@ -130,9 +148,11 @@ def cached(max_entries=1000, num_args=1, lru=False):
             try:
                 defer.returnValue(cache.get(*keyargs))
             except KeyError:
+                sequence = cache.sequence
+
                 ret = yield orig(self, *keyargs)
 
-                cache.prefill(*keyargs + (ret,))
+                cache.update(sequence, *keyargs + (ret,))
 
                 defer.returnValue(ret)