summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/push/push_tools.py7
-rw-r--r--synapse/util/async.py7
-rw-r--r--synapse/util/caches/descriptors.py42
-rw-r--r--synapse/visibility.py3
-rw-r--r--tests/storage/test__base.py2
-rw-r--r--tests/util/caches/test_descriptors.py38
-rw-r--r--tests/util/test_snapshot_cache.py4
7 files changed, 87 insertions, 16 deletions
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 287df94b4f..6835f54e97 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,15 +17,12 @@ from twisted.internet import defer
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites = yield store.get_invited_rooms_for_user(user_id)
+    joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..1453faf0ef 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
         deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5c30ed235d..9d0d0be1f9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -224,8 +224,20 @@ class _CacheDescriptorBase(object):
             )
 
         self.num_args = num_args
+
+        # list of the names of the args used as the cache key
         self.arg_names = all_args[1:num_args + 1]
 
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
         if "cache_context" in self.arg_names:
             raise Exception(
                 "cache_context arg cannot be included among the cache keys"
@@ -289,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -341,7 +366,10 @@ class CacheDescriptor(_CacheDescriptorBase):
                 cache.set(cache_key, result_d, callback=invalidate_callback)
                 observer = result_d.observe()
 
-            return logcontext.make_deferred_yieldable(observer)
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 31659156ae..c4dd9ae2c7 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         events ([synapse.events.EventBase]): list of events to filter
     """
     forgotten = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.who_forgot_in_room)(
+        defer.maybeDeferred(
+            preserve_fn(store.who_forgot_in_room),
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8361dd8cee..281eb16254 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals(a.func("foo").result, d.result)
+        self.assertEquals(a.func("foo"), d.result)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4414e86771..3f14ab503f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -175,3 +175,41 @@ class DescriptorTestCase(unittest.TestCase):
                          logcontext.LoggingContext.sentinel)
 
         return d1
+
+    @defer.inlineCallbacks
+    def test_cache_default_args(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2=2, arg3=3):
+                return self.mock(arg1, arg2, arg3)
+
+        obj = Cls()
+
+        obj.mock.return_value = 'fish'
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with same params shouldn't call the mock again
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_not_called()
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = 'chips'
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_called_once_with(2, 3, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_not_called()
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
index 7e289715ba..d3a8630c2f 100644
--- a/tests/util/test_snapshot_cache.py
+++ b/tests/util/test_snapshot_cache.py
@@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
         # before the cache expires returns a resolved deferred.
         get_result_at_11 = self.cache.get(11, "key")
         self.assertIsNotNone(get_result_at_11)
-        self.assertTrue(get_result_at_11.called)
+        if isinstance(get_result_at_11, Deferred):
+            # The cache may return the actual result rather than a deferred
+            self.assertTrue(get_result_at_11.called)
 
         # Check that getting the key after the deferred has resolved
         # after the cache expires returns None