From 95f21c7a66b256e9e5330edd0f35a83b177c84a8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 22 Mar 2017 13:54:20 +0000 Subject: Fix caching of remote servers' signature keys The `@cached` decorator on `KeyStore._get_server_verify_key` was missing its `num_args` parameter, which meant that it was returning the wrong key for any server which had more than one recorded key. By way of a fix, change the default for `num_args` to be *all* arguments. To implement that, factor out a common base class for `CacheDescriptor` and `CacheListDescriptor`. --- synapse/util/caches/descriptors.py | 135 ++++++++++++++++++---------------- tests/storage/test_keys.py | 53 +++++++++++++ tests/util/caches/__init__.py | 14 ++++ tests/util/caches/test_descriptors.py | 86 ++++++++++++++++++++++ 4 files changed, 225 insertions(+), 63 deletions(-) create mode 100644 tests/storage/test_keys.py create mode 100644 tests/util/caches/__init__.py create mode 100644 tests/util/caches/test_descriptors.py diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 998de70d29..19595df422 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -189,7 +189,55 @@ class Cache(object): self.cache.clear() -class CacheDescriptor(object): +class _CacheDescriptorBase(object): + def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + arg_spec = inspect.getargspec(orig) + all_args = arg_spec.args + + if "cache_context" in all_args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) + + if num_args is None: + num_args = len(all_args) - 1 + if cache_context: + num_args -= 1 + + if len(all_args) < num_args + 1: + raise Exception( + "Not enough explicit positional arguments to key off for %r: " + "got %i args, but wanted %i. (@cached cannot key off *args or " + "**kwargs)" + % (orig.__name__, len(all_args), num_args) + ) + + self.num_args = num_args + self.arg_names = all_args[1:num_args + 1] + + if "cache_context" in self.arg_names: + raise Exception( + "cache_context arg cannot be included among the cache keys" + ) + + self.add_cache_context = cache_context + + +class CacheDescriptor(_CacheDescriptorBase): """ A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -217,52 +265,24 @@ class CacheDescriptor(object): r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) defer.returnValue(r1 + r2) + Args: + num_args (int): number of positional arguments (excluding ``self`` and + ``cache_context``) to use as cache keys. Defaults to all named + args of the function. """ - def __init__(self, orig, max_entries=1000, num_args=1, tree=False, + def __init__(self, orig, max_entries=1000, num_args=None, tree=False, inlineCallbacks=False, cache_context=False, iterable=False): - max_entries = int(max_entries * CACHE_SIZE_FACTOR) - self.orig = orig + super(CacheDescriptor, self).__init__( + orig, num_args=num_args, inlineCallbacks=inlineCallbacks, + cache_context=cache_context) - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig + max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.max_entries = max_entries - self.num_args = num_args self.tree = tree - self.iterable = iterable - all_args = inspect.getargspec(orig) - self.arg_names = all_args.args[1:num_args + 1] - - if "cache_context" in all_args.args: - if not cache_context: - raise ValueError( - "Cannot have a 'cache_context' arg without setting" - " cache_context=True" - ) - try: - self.arg_names.remove("cache_context") - except ValueError: - pass - elif cache_context: - raise ValueError( - "Cannot have cache_context=True without having an arg" - " named `cache_context`" - ) - - self.add_cache_context = cache_context - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwargs)" - % (orig.__name__,) - ) - def __get__(self, obj, objtype=None): cache = Cache( name=self.orig.__name__, @@ -338,48 +358,36 @@ class CacheDescriptor(object): return wrapped -class CacheListDescriptor(object): +class CacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cached_method_name, list_name, num_args=1, + def __init__(self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False): """ Args: orig (function) - method_name (str); The name of the chached method. + cached_method_name (str): The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list - num_args (int) + num_args (int): number of positional arguments (excluding ``self``, + but including list_name) to use as cache keys. Defaults to all + named args of the function. inlineCallbacks (bool): Whether orig is a generator that should be wrapped by defer.inlineCallbacks """ - self.orig = orig + super(CacheListDescriptor, self).__init__( + orig, num_args=num_args, inlineCallbacks=inlineCallbacks) - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.num_args = num_args self.list_name = list_name - self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cached_method_name = cached_method_name self.sentinel = object() - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." @@ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, +def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False): return lambda orig: CacheDescriptor( orig, @@ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, - iterable=False): +def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, + cache_context=False, iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex ) -def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False) cache (Cache): The underlying cache to use. list_name (str): The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache. + num_args (int): Number of arguments to use as the key in the cache + (including list_name). Defaults to all named parameters. inlineCallbacks (bool): Should the function be wrapped in an `defer.inlineCallbacks`? diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py new file mode 100644 index 0000000000..0be790d8f8 --- /dev/null +++ b/tests/storage/test_keys.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import signedjson.key +from twisted.internet import defer + +import tests.unittest +import tests.utils + + +class KeyStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(KeyStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.keys.KeyStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_get_server_verify_keys(self): + key1 = signedjson.key.decode_verify_key_base64( + "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" + ) + key2 = signedjson.key.decode_verify_key_base64( + "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" + ) + yield self.store.store_server_verify_key( + "server1", "from_server", 0, key1 + ) + yield self.store.store_server_verify_key( + "server1", "from_server", 0, key2 + ) + + res = yield self.store.get_server_verify_keys( + "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]) + + self.assertEqual(len(res.keys()), 2) + self.assertEqual(res["ed25519:key1"].version, "key1") + self.assertEqual(res["ed25519:key2"].version, "key2") diff --git a/tests/util/caches/__init__.py b/tests/util/caches/__init__.py new file mode 100644 index 0000000000..451dae3b6c --- /dev/null +++ b/tests/util/caches/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py new file mode 100644 index 0000000000..419281054d --- /dev/null +++ b/tests/util/caches/test_descriptors.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import mock +from twisted.internet import defer +from synapse.util.caches import descriptors +from tests import unittest + + +class DescriptorTestCase(unittest.TestCase): + @defer.inlineCallbacks + def test_cache(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + + obj.mock.return_value = 'fish' + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = 'chips' + r = yield obj.fn(1, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(1, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_not_called() + + @defer.inlineCallbacks + def test_cache_num_args(self): + """Only the first num_args arguments should matter to the cache""" + + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached(num_args=1) + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + obj.mock.return_value = 'fish' + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = 'chips' + r = yield obj.fn(2, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_called_once_with(2, 3) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. + r = yield obj.fn(1, 4) + self.assertEqual(r, 'fish') + r = yield obj.fn(2, 5) + self.assertEqual(r, 'chips') + obj.mock.assert_not_called() -- cgit 1.4.1