diff options
author | Mark Haines <mark.haines@matrix.org> | 2015-05-05 14:08:03 +0100 |
---|---|---|
committer | Mark Haines <mark.haines@matrix.org> | 2015-05-05 14:13:50 +0100 |
commit | 261d809a4779b03c81ada52ed3893b2ad8782a96 (patch) | |
tree | 8c3ae56d3df1c11d711816e5764eb21caf23fc18 /synapse | |
parent | Correctly name transaction (diff) | |
download | synapse-261d809a4779b03c81ada52ed3893b2ad8782a96.tar.xz |
Sequence the modifications to the cache so that selects don't race with inserts
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/_base.py | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c328b5274c..7f5477dee5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -31,6 +31,7 @@ import functools import simplejson as json import sys import time +import threading logger = logging.getLogger(__name__) @@ -68,9 +69,20 @@ class Cache(object): self.name = name self.keylen = keylen - + self.sequence = 0 + self.thread = None caches_by_name[name] = self.cache + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + def get(self, *keyargs): if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) @@ -82,6 +94,11 @@ class Cache(object): cache_counter.inc_misses(self.name) raise KeyError() + def update(self, sequence, *args): + self.check_thread() + if self.sequence == sequence: + self.prefill(*args) + def prefill(self, *args): # because I can't *keyargs, value keyargs = args[:-1] value = args[-1] @@ -96,9 +113,10 @@ class Cache(object): self.cache[keyargs] = value def invalidate(self, *keyargs): + self.check_thread() if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) - + self.sequence += 1 self.cache.pop(keyargs, None) @@ -130,9 +148,11 @@ def cached(max_entries=1000, num_args=1, lru=False): try: defer.returnValue(cache.get(*keyargs)) except KeyError: + sequence = cache.sequence + ret = yield orig(self, *keyargs) - cache.prefill(*keyargs + (ret,)) + cache.update(sequence, *keyargs + (ret,)) defer.returnValue(ret) |