diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7736d14fb5..58a6d6a0ed 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -23,7 +23,7 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
import synapse.metrics
diff --git a/synapse/state.py b/synapse/state.py
index b5e5d7bbda..1fe4d066bd 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e5441aafb2..1444767a52 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,27 +15,22 @@
import logging
from synapse.api.errors import StoreError
-from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
-from synapse.util.lrucache import LruCache
-from synapse.util.dictionary_cache import DictionaryCache
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.descriptors import Cache
import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
-import functools
-import inspect
import sys
import time
import threading
-DEBUG_CACHES = False
logger = logging.getLogger(__name__)
@@ -51,330 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
-caches_by_name = {}
-cache_counter = metrics.register_cache(
- "cache",
- lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
- labels=["name"],
-)
-
-
-_CacheSentinel = object()
-
-
-class Cache(object):
-
- def __init__(self, name, max_entries=1000, keylen=1, lru=True):
- if lru:
- self.cache = LruCache(max_size=max_entries)
- self.max_entries = None
- else:
- self.cache = OrderedDict()
- self.max_entries = max_entries
-
- self.name = name
- self.keylen = keylen
- self.sequence = 0
- self.thread = None
- caches_by_name[name] = self.cache
-
- def check_thread(self):
- expected_thread = self.thread
- if expected_thread is None:
- self.thread = threading.current_thread()
- else:
- if expected_thread is not threading.current_thread():
- raise ValueError(
- "Cache objects can only be accessed from the main thread"
- )
-
- def get(self, key, default=_CacheSentinel):
- val = self.cache.get(key, _CacheSentinel)
- if val is not _CacheSentinel:
- cache_counter.inc_hits(self.name)
- return val
-
- cache_counter.inc_misses(self.name)
-
- if default is _CacheSentinel:
- raise KeyError()
- else:
- return default
-
- def update(self, sequence, key, value):
- self.check_thread()
- if self.sequence == sequence:
- # Only update the cache if the caches sequence number matches the
- # number that the cache had before the SELECT was started (SYN-369)
- self.prefill(key, value)
-
- def prefill(self, key, value):
- if self.max_entries is not None:
- while len(self.cache) >= self.max_entries:
- self.cache.popitem(last=False)
-
- self.cache[key] = value
-
- def invalidate(self, key):
- self.check_thread()
- if not isinstance(key, tuple):
- raise ValueError("keyargs must be a tuple.")
-
- # Increment the sequence number so that any SELECT statements that
- # raced with the INSERT don't update the cache (SYN-369)
- self.sequence += 1
- self.cache.pop(key, None)
-
- def invalidate_all(self):
- self.check_thread()
- self.sequence += 1
- self.cache.clear()
-
-
-class CacheDescriptor(object):
- """ A method decorator that applies a memoizing cache around the function.
-
- This caches deferreds, rather than the results themselves. Deferreds that
- fail are removed from the cache.
-
- The function is presumed to take zero or more arguments, which are used in
- a tuple as the key for the cache. Hits are served directly from the cache;
- misses use the function body to generate the value.
-
- The wrapped function has an additional member, a callable called
- "invalidate". This can be used to remove individual entries from the cache.
-
- The wrapped function has another additional callable, called "prefill",
- which can be used to insert values into the cache specifically, without
- calling the calculation function.
- """
- def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
- inlineCallbacks=False):
- self.orig = orig
-
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
-
- self.max_entries = max_entries
- self.num_args = num_args
- self.lru = lru
-
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
-
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
- % (orig.__name__,)
- )
-
- self.cache = Cache(
- name=self.orig.__name__,
- max_entries=self.max_entries,
- keylen=self.num_args,
- lru=self.lru,
- )
-
- def __get__(self, obj, objtype=None):
-
- @functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
- arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
- cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
- try:
- cached_result_d = self.cache.get(cache_key)
-
- observer = cached_result_d.observe()
- if DEBUG_CACHES:
- @defer.inlineCallbacks
- def check_result(cached_result):
- actual_result = yield self.function_to_call(obj, *args, **kwargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- self.orig.__name__, cache_key,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
- observer.addCallback(check_result)
-
- return observer
- except KeyError:
- # Get the sequence number of the cache before reading from the
- # database so that we can tell if the cache is invalidated
- # while the SELECT is executing (SYN-369)
- sequence = self.cache.sequence
-
- ret = defer.maybeDeferred(
- self.function_to_call,
- obj, *args, **kwargs
- )
-
- def onErr(f):
- self.cache.invalidate(cache_key)
- return f
-
- ret.addErrback(onErr)
-
- ret = ObservableDeferred(ret, consumeErrors=True)
- self.cache.update(sequence, cache_key, ret)
-
- return ret.observe()
-
- wrapped.invalidate = self.cache.invalidate
- wrapped.invalidate_all = self.cache.invalidate_all
- wrapped.prefill = self.cache.prefill
-
- obj.__dict__[self.orig.__name__] = wrapped
-
- return wrapped
-
-
-class CacheListDescriptor(object):
- """Wraps an existing cache to support bulk fetching of keys.
-
- Given a list of keys it looks in the cache to find any hits, then passes
- the list of missing keys to the wrapped fucntion.
- """
-
- def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
- """
- Args:
- orig (function)
- cache (Cache)
- list_name (str): Name of the argument which is the bulk lookup list
- num_args (int)
- inlineCallbacks (bool): Whether orig is a generator that should
- be wrapped by defer.inlineCallbacks
- """
- self.orig = orig
-
- if inlineCallbacks:
- self.function_to_call = defer.inlineCallbacks(orig)
- else:
- self.function_to_call = orig
-
- self.num_args = num_args
- self.list_name = list_name
-
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
- self.list_pos = self.arg_names.index(self.list_name)
-
- self.cache = cache
-
- self.sentinel = object()
-
- if len(self.arg_names) < self.num_args:
- raise Exception(
- "Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
- % (orig.__name__,)
- )
-
- if self.list_name not in self.arg_names:
- raise Exception(
- "Couldn't see arguments %r for %r."
- % (self.list_name, cache.name,)
- )
-
- def __get__(self, obj, objtype=None):
-
- @functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
- arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
- keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
- list_args = arg_dict[self.list_name]
-
- # cached is a dict arg -> deferred, where deferred results in a
- # 2-tuple (`arg`, `result`)
- cached = {}
- missing = []
- for arg in list_args:
- key = list(keyargs)
- key[self.list_pos] = arg
-
- try:
- res = self.cache.get(tuple(key)).observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
- except KeyError:
- missing.append(arg)
-
- if missing:
- sequence = self.cache.sequence
- args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
-
- ret_d = defer.maybeDeferred(
- self.function_to_call,
- **args_to_call
- )
-
- ret_d = ObservableDeferred(ret_d)
-
- # We need to create deferreds for each arg in the list so that
- # we can insert the new deferred into the cache.
- for arg in missing:
- observer = ret_d.observe()
- observer.addCallback(lambda r, arg: r[arg], arg)
-
- observer = ObservableDeferred(observer)
-
- key = list(keyargs)
- key[self.list_pos] = arg
- self.cache.update(sequence, tuple(key), observer)
-
- def invalidate(f, key):
- self.cache.invalidate(key)
- return f
- observer.addErrback(invalidate, tuple(key))
-
- res = observer.observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
-
- cached[arg] = res
-
- return defer.gatherResults(
- cached.values(),
- consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
-
- obj.__dict__[self.orig.__name__] = wrapped
-
- return wrapped
-
-
-def cached(max_entries=1000, num_args=1, lru=True):
- return lambda orig: CacheDescriptor(
- orig,
- max_entries=max_entries,
- num_args=num_args,
- lru=lru
- )
-
-
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
- return lambda orig: CacheDescriptor(
- orig,
- max_entries=max_entries,
- num_args=num_args,
- lru=lru,
- inlineCallbacks=True,
- )
-
-
-def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
- return lambda orig: CacheListDescriptor(
- orig,
- cache=cache,
- list_name=list_name,
- num_args=num_args,
- inlineCallbacks=inlineCallbacks,
- )
-
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index f3947bbe89..d92028ea43 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
from synapse.api.errors import SynapseError
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 910b6598a7..25cc84eb95 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -15,7 +15,8 @@
from twisted.internet import defer
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
from syutil.base64util import encode_base64
import logging
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 49b8e37cfd..ffd6daa880 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore, cachedInlineCallbacks
+from _base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 576cf670cc..4f91a2b87c 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
from twisted.internet import defer
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 9b88ca7b39..5305b7e122 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer
import logging
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index b79d6683ca..cac1a5657e 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 4eaa088b36..aa446f94c6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
class RegistrationStore(SQLBaseStore):
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index dd5bc2c8fb..5e07b7e0e5 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
import collections
import logging
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9f14f38f24..8eee2dfbcc 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
from synapse.api.constants import Membership
from synapse.types import UserID
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ea5fa9de7b..79c3b82d9f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import (
+ cached, cachedInlineCallbacks, cachedList
+)
from twisted.internet import defer
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index b59fe81004..d7fe423f5a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -35,7 +35,8 @@ what sort order was used:
from twisted.internet import defer
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 624da4a9dc..c8c7e6591a 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
from collections import namedtuple
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
new file mode 100644
index 0000000000..1a84d94cd9
--- /dev/null
+++ b/synapse/util/caches/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
new file mode 100644
index 0000000000..82dd09cf5e
--- /dev/null
+++ b/synapse/util/caches/descriptors.py
@@ -0,0 +1,359 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError
+from synapse.util.caches.lrucache import LruCache
+import synapse.metrics
+
+from twisted.internet import defer
+
+from collections import OrderedDict
+
+import functools
+import inspect
+import threading
+
+logger = logging.getLogger(__name__)
+
+
+DEBUG_CACHES = False
+
+metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
+
+caches_by_name = {}
+cache_counter = metrics.register_cache(
+ "cache",
+ lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+ labels=["name"],
+)
+
+
+_CacheSentinel = object()
+
+
+class Cache(object):
+
+ def __init__(self, name, max_entries=1000, keylen=1, lru=True):
+ if lru:
+ self.cache = LruCache(max_size=max_entries)
+ self.max_entries = None
+ else:
+ self.cache = OrderedDict()
+ self.max_entries = max_entries
+
+ self.name = name
+ self.keylen = keylen
+ self.sequence = 0
+ self.thread = None
+ caches_by_name[name] = self.cache
+
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
+ def get(self, key, default=_CacheSentinel):
+ val = self.cache.get(key, _CacheSentinel)
+ if val is not _CacheSentinel:
+ cache_counter.inc_hits(self.name)
+ return val
+
+ cache_counter.inc_misses(self.name)
+
+ if default is _CacheSentinel:
+ raise KeyError()
+ else:
+ return default
+
+ def update(self, sequence, key, value):
+ self.check_thread()
+ if self.sequence == sequence:
+ # Only update the cache if the caches sequence number matches the
+ # number that the cache had before the SELECT was started (SYN-369)
+ self.prefill(key, value)
+
+ def prefill(self, key, value):
+ if self.max_entries is not None:
+ while len(self.cache) >= self.max_entries:
+ self.cache.popitem(last=False)
+
+ self.cache[key] = value
+
+ def invalidate(self, key):
+ self.check_thread()
+ if not isinstance(key, tuple):
+ raise ValueError("keyargs must be a tuple.")
+
+ # Increment the sequence number so that any SELECT statements that
+ # raced with the INSERT don't update the cache (SYN-369)
+ self.sequence += 1
+ self.cache.pop(key, None)
+
+ def invalidate_all(self):
+ self.check_thread()
+ self.sequence += 1
+ self.cache.clear()
+
+
+class CacheDescriptor(object):
+ """ A method decorator that applies a memoizing cache around the function.
+
+ This caches deferreds, rather than the results themselves. Deferreds that
+ fail are removed from the cache.
+
+ The function is presumed to take zero or more arguments, which are used in
+ a tuple as the key for the cache. Hits are served directly from the cache;
+ misses use the function body to generate the value.
+
+ The wrapped function has an additional member, a callable called
+ "invalidate". This can be used to remove individual entries from the cache.
+
+ The wrapped function has another additional callable, called "prefill",
+ which can be used to insert values into the cache specifically, without
+ calling the calculation function.
+ """
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+ inlineCallbacks=False):
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
+ self.max_entries = max_entries
+ self.num_args = num_args
+ self.lru = lru
+
+ self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+ if len(self.arg_names) < self.num_args:
+ raise Exception(
+ "Not enough explicit positional arguments to key off of for %r."
+ " (@cached cannot key off of *args or **kwars)"
+ % (orig.__name__,)
+ )
+
+ self.cache = Cache(
+ name=self.orig.__name__,
+ max_entries=self.max_entries,
+ keylen=self.num_args,
+ lru=self.lru,
+ )
+
+ def __get__(self, obj, objtype=None):
+
+ @functools.wraps(self.orig)
+ def wrapped(*args, **kwargs):
+ arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+ cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+ try:
+ cached_result_d = self.cache.get(cache_key)
+
+ observer = cached_result_d.observe()
+ if DEBUG_CACHES:
+ @defer.inlineCallbacks
+ def check_result(cached_result):
+ actual_result = yield self.function_to_call(obj, *args, **kwargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ self.orig.__name__, cache_key,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ observer.addCallback(check_result)
+
+ return observer
+ except KeyError:
+ # Get the sequence number of the cache before reading from the
+ # database so that we can tell if the cache is invalidated
+ # while the SELECT is executing (SYN-369)
+ sequence = self.cache.sequence
+
+ ret = defer.maybeDeferred(
+ self.function_to_call,
+ obj, *args, **kwargs
+ )
+
+ def onErr(f):
+ self.cache.invalidate(cache_key)
+ return f
+
+ ret.addErrback(onErr)
+
+ ret = ObservableDeferred(ret, consumeErrors=True)
+ self.cache.update(sequence, cache_key, ret)
+
+ return ret.observe()
+
+ wrapped.invalidate = self.cache.invalidate
+ wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.prefill = self.cache.prefill
+
+ obj.__dict__[self.orig.__name__] = wrapped
+
+ return wrapped
+
+
+class CacheListDescriptor(object):
+ """Wraps an existing cache to support bulk fetching of keys.
+
+ Given a list of keys it looks in the cache to find any hits, then passes
+ the list of missing keys to the wrapped fucntion.
+ """
+
+ def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+ """
+ Args:
+ orig (function)
+ cache (Cache)
+ list_name (str): Name of the argument which is the bulk lookup list
+ num_args (int)
+ inlineCallbacks (bool): Whether orig is a generator that should
+ be wrapped by defer.inlineCallbacks
+ """
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
+ self.num_args = num_args
+ self.list_name = list_name
+
+ self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.list_pos = self.arg_names.index(self.list_name)
+
+ self.cache = cache
+
+ self.sentinel = object()
+
+ if len(self.arg_names) < self.num_args:
+ raise Exception(
+ "Not enough explicit positional arguments to key off of for %r."
+ " (@cached cannot key off of *args or **kwars)"
+ % (orig.__name__,)
+ )
+
+ if self.list_name not in self.arg_names:
+ raise Exception(
+ "Couldn't see arguments %r for %r."
+ % (self.list_name, cache.name,)
+ )
+
+ def __get__(self, obj, objtype=None):
+
+ @functools.wraps(self.orig)
+ def wrapped(*args, **kwargs):
+ arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+ keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+ list_args = arg_dict[self.list_name]
+
+ # cached is a dict arg -> deferred, where deferred results in a
+ # 2-tuple (`arg`, `result`)
+ cached = {}
+ missing = []
+ for arg in list_args:
+ key = list(keyargs)
+ key[self.list_pos] = arg
+
+ try:
+ res = self.cache.get(tuple(key)).observe()
+ res.addCallback(lambda r, arg: (arg, r), arg)
+ cached[arg] = res
+ except KeyError:
+ missing.append(arg)
+
+ if missing:
+ sequence = self.cache.sequence
+ args_to_call = dict(arg_dict)
+ args_to_call[self.list_name] = missing
+
+ ret_d = defer.maybeDeferred(
+ self.function_to_call,
+ **args_to_call
+ )
+
+ ret_d = ObservableDeferred(ret_d)
+
+ # We need to create deferreds for each arg in the list so that
+ # we can insert the new deferred into the cache.
+ for arg in missing:
+ observer = ret_d.observe()
+ observer.addCallback(lambda r, arg: r[arg], arg)
+
+ observer = ObservableDeferred(observer)
+
+ key = list(keyargs)
+ key[self.list_pos] = arg
+ self.cache.update(sequence, tuple(key), observer)
+
+ def invalidate(f, key):
+ self.cache.invalidate(key)
+ return f
+ observer.addErrback(invalidate, tuple(key))
+
+ res = observer.observe()
+ res.addCallback(lambda r, arg: (arg, r), arg)
+
+ cached[arg] = res
+
+ return defer.gatherResults(
+ cached.values(),
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+
+ obj.__dict__[self.orig.__name__] = wrapped
+
+ return wrapped
+
+
+def cached(max_entries=1000, num_args=1, lru=True):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru
+ )
+
+
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru,
+ inlineCallbacks=True,
+ )
+
+
+def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
+ return lambda orig: CacheListDescriptor(
+ orig,
+ cache=cache,
+ list_name=list_name,
+ num_args=num_args,
+ inlineCallbacks=inlineCallbacks,
+ )
diff --git a/synapse/util/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index c7564cdf0d..26d464f4f7 100644
--- a/synapse/util/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
import threading
import logging
diff --git a/synapse/util/expiringcache.py b/synapse/util/caches/expiringcache.py
index 06d1eea01b..06d1eea01b 100644
--- a/synapse/util/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
diff --git a/synapse/util/lrucache.py b/synapse/util/caches/lrucache.py
index cacd7e45fa..cacd7e45fa 100644
--- a/synapse/util/lrucache.py
+++ b/synapse/util/caches/lrucache.py
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index abee2f631d..e72cace8ff 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.util.async import ObservableDeferred
-from synapse.storage._base import Cache, cached
+from synapse.util.caches.descriptors import Cache, cached
class CacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 79bc1225d6..54ff26cd97 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -17,7 +17,7 @@
from twisted.internet import defer
from tests import unittest
-from synapse.util.dictionary_cache import DictionaryCache
+from synapse.util.caches.dictionary_cache import DictionaryCache
class DictCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index ab934bf928..fc5a904323 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -16,7 +16,7 @@
from .. import unittest
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.lrucache import LruCache
class LruCacheTestCase(unittest.TestCase):
@@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
self.assertEquals(cache.pop("key"), 1)
self.assertEquals(cache.pop("key"), None)
-
-
|