summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5749.misc1
-rw-r--r--synapse/util/caches/descriptors.py74
-rw-r--r--tests/util/caches/test_descriptors.py90
3 files changed, 130 insertions, 35 deletions
diff --git a/changelog.d/5749.misc b/changelog.d/5749.misc
new file mode 100644
index 0000000000..48dd61f461
--- /dev/null
+++ b/changelog.d/5749.misc
@@ -0,0 +1 @@
+Fix some error cases in the caching layer.
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 7e69cf55fb..43f66ec4be 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,8 +19,7 @@ import logging
 import threading
 from collections import namedtuple
 
-import six
-from six import itervalues, string_types
+from six import itervalues
 
 from prometheus_client import Gauge
 
@@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
 
 from . import register_cache
 
@@ -124,7 +122,7 @@ class Cache(object):
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
-            Either a Deferred or the raw result
+            Either an ObservableDeferred or the raw result
         """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -148,9 +146,14 @@ class Cache(object):
             return default
 
     def set(self, key, value, callback=None):
+        if not isinstance(value, defer.Deferred):
+            raise TypeError("not a Deferred")
+
         callbacks = [callback] if callback else []
         self.check_thread()
-        entry = CacheEntry(deferred=value, callbacks=callbacks)
+        observable = ObservableDeferred(value, consumeErrors=True)
+        observer = defer.maybeDeferred(observable.observe)
+        entry = CacheEntry(deferred=observable, callbacks=callbacks)
 
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
@@ -158,20 +161,31 @@ class Cache(object):
 
         self._pending_deferred_cache[key] = entry
 
-        def shuffle(result):
+        def compare_and_pop():
+            """Check if our entry is still the one in _pending_deferred_cache, and
+            if so, pop it.
+
+            Returns true if the entries matched.
+            """
             existing_entry = self._pending_deferred_cache.pop(key, None)
             if existing_entry is entry:
+                return True
+
+            # oops, the _pending_deferred_cache has been updated since
+            # we started our query, so we are out of date.
+            #
+            # Better put back whatever we took out. (We do it this way
+            # round, rather than peeking into the _pending_deferred_cache
+            # and then removing on a match, to make the common case faster)
+            if existing_entry is not None:
+                self._pending_deferred_cache[key] = existing_entry
+
+            return False
+
+        def cb(result):
+            if compare_and_pop():
                 self.cache.set(key, result, entry.callbacks)
             else:
-                # oops, the _pending_deferred_cache has been updated since
-                # we started our query, so we are out of date.
-                #
-                # Better put back whatever we took out. (We do it this way
-                # round, rather than peeking into the _pending_deferred_cache
-                # and then removing on a match, to make the common case faster)
-                if existing_entry is not None:
-                    self._pending_deferred_cache[key] = existing_entry
-
                 # we're not going to put this entry into the cache, so need
                 # to make sure that the invalidation callbacks are called.
                 # That was probably done when _pending_deferred_cache was
@@ -179,9 +193,16 @@ class Cache(object):
                 # `invalidate` being previously called, in which case it may
                 # not have been. Either way, let's double-check now.
                 entry.invalidate()
-            return result
 
-        entry.deferred.addCallback(shuffle)
+        def eb(_fail):
+            compare_and_pop()
+            entry.invalidate()
+
+        # once the deferred completes, we can move the entry from the
+        # _pending_deferred_cache to the real cache.
+        #
+        observer.addCallbacks(cb, eb)
+        return observable
 
     def prefill(self, key, value, callback=None):
         callbacks = [callback] if callback else []
@@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
-                # If our cache_key is a string on py2, try to convert to ascii
-                # to save a bit of space in large caches. Py3 does this
-                # internally automatically.
-                if six.PY2 and isinstance(cache_key, string_types):
-                    cache_key = to_ascii(cache_key)
-
-                result_d = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, result_d, callback=invalidate_callback)
+                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
                 observer = result_d.observe()
 
-            if isinstance(observer, defer.Deferred):
-                return make_deferred_yieldable(observer)
-            else:
-                return observer
+            return make_deferred_yieldable(observer)
 
         if self.num_args == 1:
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     missing.add(arg)
 
             if missing:
-                # we need an observable deferred for each entry in the list,
+                # we need a deferred for each entry in the list,
                 # which we put in the cache. Each deferred resolves with the
                 # relevant result for that key.
                 deferreds_map = {}
@@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     deferred = defer.Deferred()
                     deferreds_map[arg] = deferred
                     key = arg_to_cache_key(arg)
-                    observable = ObservableDeferred(deferred)
-                    cache.set(key, observable, callback=invalidate_callback)
+                    cache.set(key, deferred, callback=invalidate_callback)
 
                 def complete_all(res):
                     # the wrapped function has completed. It returns a
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 56320bbaf9..5713870f48 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
+from synapse.util.caches.descriptors import cached
 
 from tests import unittest
 
@@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
         d2 = defer.Deferred()
         cache.set("key2", d2, partial(record_callback, 1))
 
-        # lookup should return the deferreds
-        self.assertIs(cache.get("key1"), d1)
-        self.assertIs(cache.get("key2"), d2)
+        # lookup should return observable deferreds
+        self.assertFalse(cache.get("key1").has_called())
+        self.assertFalse(cache.get("key2").has_called())
 
         # let one of the lookups complete
         d2.callback("result2")
+
+        # for now at least, the cache will return real results rather than an
+        # observabledeferred
         self.assertEqual(cache.get("key2"), "result2")
 
         # now do the invalidation
@@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
+    def test_cache_with_sync_exception(self):
+        """If the wrapped function throws synchronously, things should continue to work
+        """
+
+        class Cls(object):
+            @cached()
+            def fn(self, arg1):
+                raise SynapseError(100, "mai spoon iz too big!!1")
+
+        obj = Cls()
+
+        # this should fail immediately
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should result in a second exception
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
     def test_cache_logcontexts(self):
         """Check that logcontexts are set and restored correctly when
         using the cache."""
@@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
 
                 self.assertEqual(LoggingContext.current_context(), c1)
 
+            # the cache should now be empty
+            self.assertEqual(len(obj.fn.cache.cache), 0)
+
         obj = Cls()
 
         # set off a deferred which will do a cache lookup
@@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
+    def test_cache_iterable(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached(iterable=True)
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+
+        obj.mock.return_value = ["spam", "eggs"]
+        r = obj.fn(1, 2)
+        self.assertEqual(r, ["spam", "eggs"])
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = ["chips"]
+        r = obj.fn(1, 3)
+        self.assertEqual(r, ["chips"])
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        self.assertEqual(len(obj.fn.cache.cache), 3)
+
+        r = obj.fn(1, 2)
+        self.assertEqual(r, ["spam", "eggs"])
+        r = obj.fn(1, 3)
+        self.assertEqual(r, ["chips"])
+        obj.mock.assert_not_called()
+
+    def test_cache_iterable_with_sync_exception(self):
+        """If the wrapped function throws synchronously, things should continue to work
+        """
+
+        class Cls(object):
+            @descriptors.cached(iterable=True)
+            def fn(self, arg1):
+                raise SynapseError(100, "mai spoon iz too big!!1")
+
+        obj = Cls()
+
+        # this should fail immediately
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should result in a second exception
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
 
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks