diff options
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r-- | synapse/util/caches/descriptors.py | 64 |
1 files changed, 38 insertions, 26 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index bf3a66eae4..68285a7594 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,12 +40,11 @@ _CacheSentinel = object() class CacheEntry(object): __slots__ = [ - "deferred", "sequence", "callbacks", "invalidated" + "deferred", "callbacks", "invalidated" ] - def __init__(self, deferred, sequence, callbacks): + def __init__(self, deferred, callbacks): self.deferred = deferred - self.sequence = sequence self.callbacks = set(callbacks) self.invalidated = False @@ -62,7 +62,6 @@ class Cache(object): "max_entries", "name", "keylen", - "sequence", "thread", "metrics", "_pending_deferred_cache", @@ -80,7 +79,6 @@ class Cache(object): self.name = name self.keylen = keylen - self.sequence = 0 self.thread = None self.metrics = register_cache(name, self.cache) @@ -113,11 +111,10 @@ class Cache(object): callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - if val.sequence == self.sequence: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: @@ -137,12 +134,9 @@ class Cache(object): self.check_thread() entry = CacheEntry( deferred=value, - sequence=self.sequence, callbacks=callbacks, ) - entry.callbacks.update(callbacks) - existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: existing_entry.invalidate() @@ -150,13 +144,25 @@ class Cache(object): self._pending_deferred_cache[key] = entry def shuffle(result): - if self.sequence == entry.sequence: - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - self.cache.set(key, result, entry.callbacks) - else: - entry.invalidate() + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, result, entry.callbacks) else: + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. entry.invalidate() return result @@ -168,25 +174,29 @@ class Cache(object): def invalidate(self, key): self.check_thread() + self.cache.pop(key, None) - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. if entry: entry.invalidate() - self.cache.pop(key, None) - def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): raise TypeError( "The cache key must be a tuple not %r" % (type(key),) ) - self.sequence += 1 self.cache.del_multi(key) + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): @@ -194,8 +204,10 @@ class Cache(object): def invalidate_all(self): self.check_thread() - self.sequence += 1 self.cache.clear() + for entry in self._pending_deferred_cache.itervalues(): + entry.invalidate() + self._pending_deferred_cache.clear() class _CacheDescriptorBase(object): |