diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d29..19595df422 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -189,7 +189,55 @@ class Cache(object):
self.cache.clear()
-class CacheDescriptor(object):
+class _CacheDescriptorBase(object):
+ def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
+ arg_spec = inspect.getargspec(orig)
+ all_args = arg_spec.args
+
+ if "cache_context" in all_args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
+
+ if num_args is None:
+ num_args = len(all_args) - 1
+ if cache_context:
+ num_args -= 1
+
+ if len(all_args) < num_args + 1:
+ raise Exception(
+ "Not enough explicit positional arguments to key off for %r: "
+ "got %i args, but wanted %i. (@cached cannot key off *args or "
+ "**kwargs)"
+ % (orig.__name__, len(all_args), num_args)
+ )
+
+ self.num_args = num_args
+ self.arg_names = all_args[1:num_args + 1]
+
+ if "cache_context" in self.arg_names:
+ raise Exception(
+ "cache_context arg cannot be included among the cache keys"
+ )
+
+ self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +265,24 @@ class CacheDescriptor(object):
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
+ Args:
+ num_args (int): number of positional arguments (excluding ``self`` and
+ ``cache_context``) to use as cache keys. Defaults to all named
+ args of the function.
"""
- def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+ def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
inlineCallbacks=False, cache_context=False, iterable=False):
- max_entries = int(max_entries * CACHE_SIZE_FACTOR)
- self.orig = orig
+ super(CacheDescriptor, self).__init__(
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+ cache_context=cache_context)
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
+ max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.max_entries = max_entries
- self.num_args = num_args
self.tree = tree
-
self.iterable = iterable
- all_args = inspect.getargspec(orig)
- self.arg_names = all_args.args[1:num_args + 1]
-
- if "cache_context" in all_args.args:
- if not cache_context:
- raise ValueError(
- "Cannot have a 'cache_context' arg without setting"
- " cache_context=True"
- )
- try:
- self.arg_names.remove("cache_context")
- except ValueError:
- pass
- elif cache_context:
- raise ValueError(
- "Cannot have cache_context=True without having an arg"
- " named `cache_context`"
- )
-
- self.add_cache_context = cache_context
-
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwargs)"
- % (orig.__name__,)
- )
-
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
@@ -338,48 +358,36 @@ class CacheDescriptor(object):
return wrapped
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=1,
+ def __init__(self, orig, cached_method_name, list_name, num_args=None,
inlineCallbacks=False):
"""
Args:
orig (function)
- method_name (str); The name of the chached method.
+ cached_method_name (str): The name of the chached method.
list_name (str): Name of the argument which is the bulk lookup list
- num_args (int)
+ num_args (int): number of positional arguments (excluding ``self``,
+ but including list_name) to use as cache keys. Defaults to all
+ named args of the function.
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
- self.orig = orig
+ super(CacheListDescriptor, self).__init__(
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
-
- self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
-
self.cached_method_name = cached_method_name
self.sentinel = object()
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
- % (orig.__name__,)
- )
-
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
@@ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
@@ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
- iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+ cache_context=False, iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
)
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
cache (Cache): The underlying cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
- num_args (int): Number of arguments to use as the key in the cache.
+ num_args (int): Number of arguments to use as the key in the cache
+ (including list_name). Defaults to all named parameters.
inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`?
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
new file mode 100644
index 0000000000..0be790d8f8
--- /dev/null
+++ b/tests/storage/test_keys.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import signedjson.key
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+
+class KeyStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(KeyStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.keys.KeyStore
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_get_server_verify_keys(self):
+ key1 = signedjson.key.decode_verify_key_base64(
+ "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+ )
+ key2 = signedjson.key.decode_verify_key_base64(
+ "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+ )
+ yield self.store.store_server_verify_key(
+ "server1", "from_server", 0, key1
+ )
+ yield self.store.store_server_verify_key(
+ "server1", "from_server", 0, key2
+ )
+
+ res = yield self.store.get_server_verify_keys(
+ "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
+
+ self.assertEqual(len(res.keys()), 2)
+ self.assertEqual(res["ed25519:key1"].version, "key1")
+ self.assertEqual(res["ed25519:key2"].version, "key2")
diff --git a/tests/util/caches/__init__.py b/tests/util/caches/__init__.py
new file mode 100644
index 0000000000..451dae3b6c
--- /dev/null
+++ b/tests/util/caches/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
new file mode 100644
index 0000000000..419281054d
--- /dev/null
+++ b/tests/util/caches/test_descriptors.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import mock
+from twisted.internet import defer
+from synapse.util.caches import descriptors
+from tests import unittest
+
+
+class DescriptorTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def test_cache(self):
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+
+ obj.mock.return_value = 'fish'
+ r = yield obj.fn(1, 2)
+ self.assertEqual(r, 'fish')
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = 'chips'
+ r = yield obj.fn(1, 3)
+ self.assertEqual(r, 'chips')
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached
+ r = yield obj.fn(1, 2)
+ self.assertEqual(r, 'fish')
+ r = yield obj.fn(1, 3)
+ self.assertEqual(r, 'chips')
+ obj.mock.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_cache_num_args(self):
+ """Only the first num_args arguments should matter to the cache"""
+
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached(num_args=1)
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+ obj.mock.return_value = 'fish'
+ r = yield obj.fn(1, 2)
+ self.assertEqual(r, 'fish')
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = 'chips'
+ r = yield obj.fn(2, 3)
+ self.assertEqual(r, 'chips')
+ obj.mock.assert_called_once_with(2, 3)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached; we should be able to vary
+ # the second argument and still get the cached result.
+ r = yield obj.fn(1, 4)
+ self.assertEqual(r, 'fish')
+ r = yield obj.fn(2, 5)
+ self.assertEqual(r, 'chips')
+ obj.mock.assert_not_called()
|