diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 98a5a26ac5..2a2360ab5d 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class DeferredTimedOutError(SynapseError):
def __init__(self):
- super(SynapseError, self).__init__(504, "Timed out")
+ super(DeferredTimedOutError, self).__init__(504, "Timed out")
def unwrapFirstError(failure):
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..1453faf0ef 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
def observe(self):
+ """Observe the underlying deferred.
+
+ Returns a deferred if the underlying deferred is still pending (or has
+ failed), otherwise the actual resolved value. Callers may need to use maybeDeferred.
+ """
if not self._result:
d = defer.Deferred()
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
return d
else:
success, res = self._result
- return defer.succeed(res) if success else defer.fail(res)
+ return res if success else defer.fail(res)
def observers(self):
return self._observers
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8a7774a88e..4a83c46d98 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -14,13 +14,10 @@
# limitations under the License.
import synapse.metrics
-from lrucache import LruCache
import os
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-DEBUG_CACHES = False
-
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
@@ -40,10 +37,6 @@ def register_cache(name, cache):
)
-_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
-_stirng_cache_metrics = register_cache("string_cache", _string_cache)
-
-
KNOWN_KEYS = {
key: key for key in
(
@@ -67,14 +60,16 @@ KNOWN_KEYS = {
def intern_string(string):
- """Takes a (potentially) unicode string and interns using custom cache
+ """Takes a (potentially) unicode string and interns it if it's ascii
"""
- new_str = _string_cache.setdefault(string, string)
- if new_str is string:
- _stirng_cache_metrics.inc_hits()
- else:
- _stirng_cache_metrics.inc_misses()
- return new_str
+ if string is None:
+ return None
+
+ try:
+ string = string.encode("ascii")
+ return intern(string)
+ except UnicodeEncodeError:
+ return string
def intern_dict(dictionary):
@@ -87,13 +82,9 @@ def intern_dict(dictionary):
def _intern_known_values(key, value):
- intern_str_keys = ("event_id", "room_id")
- intern_unicode_keys = ("sender", "user_id", "type", "state_key")
-
- if key in intern_str_keys:
- return intern(value.encode('ascii'))
+ intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
- if key in intern_unicode_keys:
+ if key in intern_keys:
return intern_string(value)
return value
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5c30ed235d..48dcbafeef 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,8 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.stringutils import to_ascii
-from . import DEBUG_CACHES, register_cache
+from . import register_cache
from twisted.internet import defer
from collections import namedtuple
@@ -76,7 +77,7 @@ class Cache(object):
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
- size_callback=(lambda d: len(d.result)) if iterable else None,
+ size_callback=(lambda d: len(d)) if iterable else None,
)
self.name = name
@@ -95,13 +96,26 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, default=_CacheSentinel, callback=None):
+ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+ """Looks the key up in the caches.
+
+ Args:
+ key(tuple)
+ default: What is returned if key is not in the caches. If not
+ specified then the function raises a KeyError instead
+ callback(fn): Gets called when the entry in the cache is invalidated
+ update_metrics (bool): whether to update the cache hit rate metrics
+
+ Returns:
+ Either a Deferred or the raw result
+ """
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks)
- self.metrics.inc_hits()
+ if update_metrics:
+ self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
@@ -109,7 +123,8 @@ class Cache(object):
self.metrics.inc_hits()
return val
- self.metrics.inc_misses()
+ if update_metrics:
+ self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
@@ -137,7 +152,7 @@ class Cache(object):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
- self.cache.set(key, entry.deferred, entry.callbacks)
+ self.cache.set(key, result, entry.callbacks)
else:
entry.invalidate()
else:
@@ -152,10 +167,6 @@ class Cache(object):
def invalidate(self, key):
self.check_thread()
- if not isinstance(key, tuple):
- raise TypeError(
- "The cache key must be a tuple not %r" % (type(key),)
- )
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
@@ -224,8 +235,20 @@ class _CacheDescriptorBase(object):
)
self.num_args = num_args
+
+ # list of the names of the args used as the cache key
self.arg_names = all_args[1:num_args + 1]
+ # self.arg_defaults is a map of arg name to its default value for each
+ # argument that has a default value
+ if arg_spec.defaults:
+ self.arg_defaults = dict(zip(
+ all_args[-len(arg_spec.defaults):],
+ arg_spec.defaults
+ ))
+ else:
+ self.arg_defaults = {}
+
if "cache_context" in self.arg_names:
raise Exception(
"cache_context arg cannot be included among the cache keys"
@@ -289,18 +312,47 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
)
+ def get_cache_key_gen(args, kwargs):
+ """Given some args/kwargs return a generator that resolves into
+ the cache_key.
+
+ We loop through each arg name, looking up if its in the `kwargs`,
+ otherwise using the next argument in `args`. If there are no more
+ args then we try looking the arg name up in the defaults
+ """
+ pos = 0
+ for nm in self.arg_names:
+ if nm in kwargs:
+ yield kwargs[nm]
+ elif pos < len(args):
+ yield args[pos]
+ pos += 1
+ else:
+ yield self.arg_defaults[nm]
+
+ # By default our cache key is a tuple, but if there is only one item
+ # then don't bother wrapping in a tuple. This is to save memory.
+ if self.num_args == 1:
+ nm = self.arg_names[0]
+
+ def get_cache_key(args, kwargs):
+ if nm in kwargs:
+ return kwargs[nm]
+ elif len(args):
+ return args[0]
+ else:
+ return self.arg_defaults[nm]
+ else:
+ def get_cache_key(args, kwargs):
+ return tuple(get_cache_key_gen(args, kwargs))
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
- # Add temp cache_context so inspect.getcallargs doesn't explode
- if self.add_cache_context:
- kwargs["cache_context"] = None
-
- arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
- cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+ cache_key = get_cache_key(args, kwargs)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
@@ -310,20 +362,10 @@ class CacheDescriptor(_CacheDescriptorBase):
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
- observer = cached_result_d.observe()
- if DEBUG_CACHES:
- @defer.inlineCallbacks
- def check_result(cached_result):
- actual_result = yield self.function_to_call(obj, *args, **kwargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- self.orig.__name__, cache_key,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
- observer.addCallback(check_result)
+ if isinstance(cached_result_d, ObservableDeferred):
+ observer = cached_result_d.observe()
+ else:
+ observer = cached_result_d
except KeyError:
ret = defer.maybeDeferred(
@@ -337,16 +379,30 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
+ # If our cache_key is a string, try to convert to ascii to save
+ # a bit of space in large caches
+ if isinstance(cache_key, basestring):
+ cache_key = to_ascii(cache_key)
+
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe()
- return logcontext.make_deferred_yieldable(observer)
+ if isinstance(observer, defer.Deferred):
+ return logcontext.make_deferred_yieldable(observer)
+ else:
+ return observer
+
+ if self.num_args == 1:
+ wrapped.invalidate = lambda key: cache.invalidate(key[0])
+ wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+ else:
+ wrapped.invalidate = cache.invalidate
+ wrapped.invalidate_all = cache.invalidate_all
+ wrapped.invalidate_many = cache.invalidate_many
+ wrapped.prefill = cache.prefill
- wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
- wrapped.invalidate_many = cache.invalidate_many
- wrapped.prefill = cache.prefill
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
@@ -419,7 +475,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
try:
res = cache.get(tuple(key), callback=invalidate_callback)
- if not res.has_succeeded():
+ if not isinstance(res, ObservableDeferred):
+ results[arg] = res
+ elif not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached_defers[arg] = res
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 857afee7cb..990216145e 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -334,12 +334,8 @@ def preserve_fn(f):
LoggingContext.set_current_context(LoggingContext.sentinel)
return result
- # XXX: why is this here rather than inside g? surely we want to preserve
- # the context from the time the function was called, not when it was
- # wrapped?
- current = LoggingContext.current_context()
-
def g(*args, **kwargs):
+ current = LoggingContext.current_context()
res = f(*args, **kwargs)
if isinstance(res, defer.Deferred) and not res.called:
# The function will have reset the context before returning, so
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index a100f151d4..95a6168e16 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -40,3 +40,17 @@ def is_ascii(s):
return False
else:
return True
+
+
+def to_ascii(s):
+ """Converts a string to ascii if it is ascii, otherwise leave it alone.
+
+ If given None then will return None.
+ """
+ if s is None:
+ return None
+
+ try:
+ return s.encode("ascii")
+ except UnicodeEncodeError:
+ return s