summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py16
-rw-r--r--synapse/util/caches/descriptors.py34
-rw-r--r--synapse/util/caches/dictionary_cache.py34
-rw-r--r--synapse/util/caches/expiringcache.py5
-rw-r--r--synapse/util/caches/lrucache.py2
-rw-r--r--synapse/util/caches/stream_change_cache.py62
-rw-r--r--synapse/util/caches/treecache.py6
7 files changed, 89 insertions, 70 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 183faf75a1..7b065b195e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -13,15 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from prometheus_client.core import Gauge, REGISTRY, GaugeMetricFamily
-
 import os
 
-from six.moves import intern
 import six
+from six.moves import intern
+
+from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily
 
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
 
+
+def get_cache_factor_for(cache_name):
+    env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper()
+    factor = os.environ.get(env_var)
+    if factor:
+        return float(factor)
+
+    return CACHE_SIZE_FACTOR
+
+
 caches_by_name = {}
 collectors_by_name = {}
 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8a9dcb2fc2..f8a07df6b8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,25 +13,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
+import inspect
 import logging
+import threading
+from collections import namedtuple
 
+import six
+from six import itervalues, string_types
+
+from twisted.internet import defer
+
+from synapse.util import logcontext, unwrapFirstError
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError, logcontext
-from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 from synapse.util.stringutils import to_ascii
 
 from . import register_cache
 
-from twisted.internet import defer
-from collections import namedtuple
-
-import functools
-import inspect
-import threading
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -205,7 +206,7 @@ class Cache(object):
     def invalidate_all(self):
         self.check_thread()
         self.cache.clear()
-        for entry in self._pending_deferred_cache.itervalues():
+        for entry in itervalues(self._pending_deferred_cache):
             entry.invalidate()
         self._pending_deferred_cache.clear()
 
@@ -310,7 +311,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
             cache_context=cache_context)
 
-        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
+        max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
 
         self.max_entries = max_entries
         self.tree = tree
@@ -392,9 +393,10 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
-                # If our cache_key is a string, try to convert to ascii to save
-                # a bit of space in large caches
-                if isinstance(cache_key, basestring):
+                # If our cache_key is a string on py2, try to convert to ascii
+                # to save a bit of space in large caches. Py3 does this
+                # internally automatically.
+                if six.PY2 and isinstance(cache_key, string_types):
                     cache_key = to_ascii(cache_key)
 
                 result_d = ObservableDeferred(ret, consumeErrors=True)
@@ -565,7 +567,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     return results
 
                 return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    cached_defers.values(),
+                    list(cached_defers.values()),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
                     unwrapFirstError
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index bdc21e348f..6c0b5a4094 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches.lrucache import LruCache
-from collections import namedtuple
-from . import register_cache
-import threading
 import logging
+import threading
+from collections import namedtuple
 
+from synapse.util.caches.lrucache import LruCache
+
+from . import register_cache
 
 logger = logging.getLogger(__name__)
 
@@ -107,29 +108,28 @@ class DictionaryCache(object):
         self.sequence += 1
         self.cache.clear()
 
-    def update(self, sequence, key, value, full=False, known_absent=None):
+    def update(self, sequence, key, value, fetched_keys=None):
         """Updates the entry in the cache
 
         Args:
             sequence
-            key
-            value (dict): The value to update the cache with.
-            full (bool): Whether the given value is the full dict, or just a
-                partial subset there of. If not full then any existing entries
-                for the key will be updated.
-            known_absent (set): Set of keys that we know don't exist in the full
-                dict.
+            key (K)
+            value (dict[X,Y]): The value to update the cache with.
+            fetched_keys (None|set[X]): All of the dictionary keys which were
+                fetched from the database.
+
+                If None, this is the complete value for key K. Otherwise, it
+                is used to infer a list of keys which we know don't exist in
+                the full dict.
         """
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            if known_absent is None:
-                known_absent = set()
-            if full:
-                self._insert(key, value, known_absent)
+            if fetched_keys is None:
+                self._insert(key, value, set())
             else:
-                self._update_or_insert(key, value, known_absent)
+                self._update_or_insert(key, value, fetched_keys)
 
     def _update_or_insert(self, key, value, known_absent):
         # We pop and reinsert as we need to tell the cache the size may have
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index ff04c91955..4abca91f6d 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,11 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import register_cache
-
-from collections import OrderedDict
 import logging
+from collections import OrderedDict
 
+from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 1c5a982094..b684f24e7b 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 
 
-from functools import wraps
 import threading
+from functools import wraps
 
 from synapse.util.caches.treecache import TreeCache
 
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index a7fe0397fa..8637867c6d 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
-
-
-from blist import sorteddict
 import logging
 
+from sortedcontainers import SortedDict
+
+from synapse.util import caches
 
 logger = logging.getLogger(__name__)
 
@@ -32,16 +31,18 @@ class StreamChangeCache(object):
     entities that may have changed since that position. If position key is too
     old then the cache will simply return all given entities.
     """
-    def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
-        self._max_size = int(max_size * CACHE_SIZE_FACTOR)
+
+    def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None):
+        self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
         self._entity_to_key = {}
-        self._cache = sorteddict()
+        self._cache = SortedDict()
         self._earliest_known_stream_pos = current_stream_pos
         self.name = name
-        self.metrics = register_cache("cache", self.name, self._cache)
+        self.metrics = caches.register_cache("cache", self.name, self._cache)
 
-        for entity, stream_pos in prefilled_cache.items():
-            self.entity_has_changed(entity, stream_pos)
+        if prefilled_cache:
+            for entity, stream_pos in prefilled_cache.items():
+                self.entity_has_changed(entity, stream_pos)
 
     def has_entity_changed(self, entity, stream_pos):
         """Returns True if the entity may have been updated since stream_pos
@@ -65,22 +66,26 @@ class StreamChangeCache(object):
         return False
 
     def get_entities_changed(self, entities, stream_pos):
-        """Returns subset of entities that have had new things since the
-        given position. If the position is too old it will just return the given list.
+        """
+        Returns subset of entities that have had new things since the given
+        position.  Entities unknown to the cache will be returned.  If the
+        position is too old it will just return the given list.
         """
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            keys = self._cache.keys()
-            i = keys.bisect_right(stream_pos)
+            not_known_entities = set(entities) - set(self._entity_to_key)
 
-            result = set(
-                self._cache[k] for k in keys[i:]
-            ).intersection(entities)
+            result = (
+                {self._cache[k] for k in self._cache.islice(
+                    start=self._cache.bisect_right(stream_pos))}
+                .intersection(entities)
+                .union(not_known_entities)
+            )
 
             self.metrics.inc_hits()
         else:
-            result = entities
+            result = set(entities)
             self.metrics.inc_misses()
 
         return result
@@ -90,12 +95,13 @@ class StreamChangeCache(object):
         """
         assert type(stream_pos) is int
 
+        if not self._cache:
+            # If we have no cache, nothing can have changed.
+            return False
+
         if stream_pos >= self._earliest_known_stream_pos:
             self.metrics.inc_hits()
-            keys = self._cache.keys()
-            i = keys.bisect_right(stream_pos)
-
-            return i < len(keys)
+            return self._cache.bisect_right(stream_pos) < len(self._cache)
         else:
             self.metrics.inc_misses()
             return True
@@ -107,10 +113,8 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            keys = self._cache.keys()
-            i = keys.bisect_right(stream_pos)
-
-            return [self._cache[k] for k in keys[i:]]
+            return [self._cache[k] for k in self._cache.islice(
+                start=self._cache.bisect_right(stream_pos))]
         else:
             return None
 
@@ -129,8 +133,10 @@ class StreamChangeCache(object):
             self._entity_to_key[entity] = stream_pos
 
             while len(self._cache) > self._max_size:
-                k, r = self._cache.popitem()
-                self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+                k, r = self._cache.popitem(0)
+                self._earliest_known_stream_pos = max(
+                    k, self._earliest_known_stream_pos,
+                )
                 self._entity_to_key.pop(r, None)
 
     def get_max_pos_of_last_change(self, entity):
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index fcc341a6b7..dd4c9e6067 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from six import itervalues
+
 SENTINEL = object()
 
 
@@ -49,7 +51,7 @@ class TreeCache(object):
         if popped is SENTINEL:
             return default
 
-        node_and_keys = zip(nodes, key)
+        node_and_keys = list(zip(nodes, key))
         node_and_keys.reverse()
         node_and_keys.append((self.root, None))
 
@@ -76,7 +78,7 @@ def iterate_tree_cache_entry(d):
     can contain dicts.
     """
     if isinstance(d, dict):
-        for value_d in d.itervalues():
+        for value_d in itervalues(d):
             for value in iterate_tree_cache_entry(value_d):
                 yield value
     else: