summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/descriptors.py56
-rw-r--r--synapse/util/caches/dictionary_cache.py57
-rw-r--r--synapse/util/caches/stream_change_cache.py15
3 files changed, 101 insertions, 27 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 48dcbafeef..cbdff86596 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase):
 
         wrapped.invalidate_all = cache.invalidate_all
         wrapped.cache = cache
+        wrapped.num_args = self.num_args
 
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
             )
 
     def __get__(self, obj, objtype=None):
-
-        cache = getattr(obj, self.cached_method_name).cache
+        cached_method = getattr(obj, self.cached_method_name)
+        cache = cached_method.cache
+        num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
@@ -469,12 +471,23 @@ class CacheListDescriptor(_CacheDescriptorBase):
             results = {}
             cached_defers = {}
             missing = []
-            for arg in list_args:
+
+            # If the cache takes a single arg then that is used as the key,
+            # otherwise a tuple is used.
+            if num_args == 1:
+                def cache_get(arg):
+                    return cache.get(arg, callback=invalidate_callback)
+            else:
                 key = list(keyargs)
-                key[self.list_pos] = arg
 
+                def cache_get(arg):
+                    key[self.list_pos] = arg
+                    return cache.get(tuple(key), callback=invalidate_callback)
+
+            for arg in list_args:
                 try:
-                    res = cache.get(tuple(key), callback=invalidate_callback)
+                    res = cache_get(arg)
+
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
@@ -505,17 +518,28 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
                     observer = ObservableDeferred(observer)
 
-                    key = list(keyargs)
-                    key[self.list_pos] = arg
-                    cache.set(
-                        tuple(key), observer,
-                        callback=invalidate_callback
-                    )
-
-                    def invalidate(f, key):
-                        cache.invalidate(key)
-                        return f
-                    observer.addErrback(invalidate, tuple(key))
+                    if num_args == 1:
+                        cache.set(
+                            arg, observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, arg)
+                    else:
+                        key = list(keyargs)
+                        key[self.list_pos] = arg
+                        cache.set(
+                            tuple(key), observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, tuple(key))
 
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index cb6933c61c..d4105822b3 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,17 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+    """Returned when getting an entry from the cache
+
+    Attributes:
+        full (bool): Whether the cache has the full or dict or just some keys.
+            If not full then not all requested keys will necessarily be present
+            in `value`
+        known_absent (set): Keys that were looked up in the dict and were not
+            there.
+        value (dict): The full or partial dict value
+    """
     def __len__(self):
         return len(self.value)
 
@@ -58,21 +68,31 @@ class DictionaryCache(object):
                 )
 
     def get(self, key, dict_keys=None):
+        """Fetch an entry out of the cache
+
+        Args:
+            key
+            dict_key(list): If given a set of keys then return only those keys
+                that exist in the cache.
+
+        Returns:
+            DictionaryEntry
+        """
         entry = self.cache.get(key, self.sentinel)
         if entry is not self.sentinel:
             self.metrics.inc_hits()
 
             if dict_keys is None:
-                return DictionaryEntry(entry.full, dict(entry.value))
+                return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
             else:
-                return DictionaryEntry(entry.full, {
+                return DictionaryEntry(entry.full, entry.known_absent, {
                     k: entry.value[k]
                     for k in dict_keys
                     if k in entry.value
                 })
 
         self.metrics.inc_misses()
-        return DictionaryEntry(False, {})
+        return DictionaryEntry(False, set(), {})
 
     def invalidate(self, key):
         self.check_thread()
@@ -87,19 +107,34 @@ class DictionaryCache(object):
         self.sequence += 1
         self.cache.clear()
 
-    def update(self, sequence, key, value, full=False):
+    def update(self, sequence, key, value, full=False, known_absent=None):
+        """Updates the entry in the cache
+
+        Args:
+            sequence
+            key
+            value (dict): The value to update the cache with.
+            full (bool): Whether the given value is the full dict, or just a
+                partial subset there of. If not full then any existing entries
+                for the key will be updated.
+            known_absent (set): Set of keys that we know don't exist in the full
+                dict.
+        """
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
+            if known_absent is None:
+                known_absent = set()
             if full:
-                self._insert(key, value)
+                self._insert(key, value, known_absent)
             else:
-                self._update_or_insert(key, value)
+                self._update_or_insert(key, value, known_absent)
 
-    def _update_or_insert(self, key, value):
-        entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+    def _update_or_insert(self, key, value, known_absent):
+        entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
         entry.value.update(value)
+        entry.known_absent.update(known_absent)
 
-    def _insert(self, key, value):
-        self.cache[key] = DictionaryEntry(True, value)
+    def _insert(self, key, value, known_absent):
+        self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 70fe00ce0b..609625b322 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -89,6 +89,21 @@ class StreamChangeCache(object):
 
         return result
 
+    def has_any_entity_changed(self, stream_pos):
+        """Returns if any entity has changed
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos >= self._earliest_known_stream_pos:
+            self.metrics.inc_hits()
+            keys = self._cache.keys()
+            i = keys.bisect_right(stream_pos)
+
+            return i < len(keys)
+        else:
+            self.metrics.inc_misses()
+            return True
+
     def get_all_entities_changed(self, stream_pos):
         """Returns all entites that have had new things since the given
         position. If the position is too old it will return None.