summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py42
1 files changed, 33 insertions, 9 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 807e147657..aa182eeac7 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,7 @@ from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.stringutils import to_ascii
 
 from . import register_cache
 
@@ -163,10 +164,6 @@ class Cache(object):
 
     def invalidate(self, key):
         self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
 
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
@@ -312,7 +309,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
-        def get_cache_key(args, kwargs):
+        def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into
             the cache_key.
 
@@ -330,13 +327,29 @@ class CacheDescriptor(_CacheDescriptorBase):
                 else:
                     yield self.arg_defaults[nm]
 
+        # By default our cache key is a tuple, but if there is only one item
+        # then don't bother wrapping in a tuple.  This is to save memory.
+        if self.num_args == 1:
+            nm = self.arg_names[0]
+
+            def get_cache_key(args, kwargs):
+                if nm in kwargs:
+                    return kwargs[nm]
+                elif len(args):
+                    return args[0]
+                else:
+                    return self.arg_defaults[nm]
+        else:
+            def get_cache_key(args, kwargs):
+                return tuple(get_cache_key_gen(args, kwargs))
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            cache_key = tuple(get_cache_key(args, kwargs))
+            cache_key = get_cache_key(args, kwargs)
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -363,6 +376,11 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
+                # If our cache_key is a string, try to convert to ascii to save
+                # a bit of space in large caches
+                if isinstance(cache_key, basestring):
+                    cache_key = to_ascii(cache_key)
+
                 result_d = ObservableDeferred(ret, consumeErrors=True)
                 cache.set(cache_key, result_d, callback=invalidate_callback)
                 observer = result_d.observe()
@@ -372,10 +390,16 @@ class CacheDescriptor(_CacheDescriptorBase):
             else:
                 return observer
 
-        wrapped.invalidate = cache.invalidate
+        if self.num_args == 1:
+            wrapped.invalidate = lambda key: cache.invalidate(key[0])
+            wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+        else:
+            wrapped.invalidate = cache.invalidate
+            wrapped.invalidate_all = cache.invalidate_all
+            wrapped.invalidate_many = cache.invalidate_many
+            wrapped.prefill = cache.prefill
+
         wrapped.invalidate_all = cache.invalidate_all
-        wrapped.invalidate_many = cache.invalidate_many
-        wrapped.prefill = cache.prefill
         wrapped.cache = cache
 
         obj.__dict__[self.orig.__name__] = wrapped