Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--  synapse/util/caches/lrucache.py | 125
1 file changed, 106 insertions(+), 19 deletions(-)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a21d34fcb4..1be675e014 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -17,8 +17,10 @@ from functools import wraps
 from typing import (
     Any,
     Callable,
+    Collection,
     Generic,
     Iterable,
+    List,
     Optional,
     Type,
     TypeVar,
@@ -30,9 +32,36 @@ from typing import (
 from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
+from synapse.util import caches
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache
 
+try:
+    from pympler.asizeof import Asizer
+
+    def _get_size_of(val: Any, *, recurse=True) -> int:
+        """Get an estimate of the size in bytes of the object.
+
+        Args:
+            val: The object to size.
+            recurse: If true will include referenced values in the size,
+                otherwise only sizes the given object.
+        """
+        # Ignore singleton values when calculating memory usage.
+        if val in ((), None, ""):
+            return 0
+
+        sizer = Asizer()
+        sizer.exclude_refs((), None, "")
+        return sizer.asizeof(val, limit=100 if recurse else 0)
+
+
+except ImportError:
+
+    def _get_size_of(val: Any, *, recurse=True) -> int:
+        return 0
+
+
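The helper above estimates an object's footprint with pympler when that library is available, and degrades to reporting zero bytes otherwise, so memory tracking never becomes a hard dependency. A rough sketch of its behaviour, assuming pympler is installed (exact byte counts vary across interpreter versions and platforms):

    # Shared singletons are treated as free: charging them to every cache
    # entry that references them would overstate real usage.
    _get_size_of(None)    # 0
    _get_size_of("")      # 0

    # recurse=True (the default) also sizes referenced values, up to a
    # recursion limit of 100; recurse=False counts only the object itself.
    _get_size_of(("room", "event"))                  # tuple plus both strings
    _get_size_of(("room", "event"), recurse=False)   # tuple header only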
 # Function type: the type used for invalidation callbacks
 FT = TypeVar("FT", bound=Callable[..., Any])
 
@@ -54,16 +83,69 @@ def enumerate_leaves(node, depth):
 
 
 class _Node:
-    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
+    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
 
     def __init__(
-        self, prev_node, next_node, key, value, callbacks: Optional[set] = None
+        self,
+        prev_node,
+        next_node,
+        key,
+        value,
+        callbacks: Collection[Callable[[], None]] = (),
     ):
         self.prev_node = prev_node
         self.next_node = next_node
         self.key = key
         self.value = value
-        self.callbacks = callbacks or set()
+
+        # Set of callbacks to run when the node gets deleted. We store as a list
+        # rather than a set to keep memory usage down (and since we expect few
+        # entries per node, the performance of checking for duplication in a
+        # list vs using a set is negligible).
+        #
+        # Note that we store this as an optional list to keep the memory
+        # footprint down. Storing `None` is free as it's a singleton, while empty
+        # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
+        # thing and used sets).
+        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+
+        self.add_callbacks(callbacks)
+
+        self.memory = 0
+        if caches.TRACK_MEMORY_USAGE:
+            self.memory = (
+                _get_size_of(key)
+                + _get_size_of(value)
+                + _get_size_of(self.callbacks, recurse=False)
+                + _get_size_of(self, recurse=False)
+            )
+            self.memory += _get_size_of(self.memory, recurse=False)
+
+    def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
+        """Add to stored list of callbacks, removing duplicates."""
+
+        if not callbacks:
+            return
+
+        if not self.callbacks:
+            self.callbacks = []
+
+        for callback in callbacks:
+            if callback not in self.callbacks:
+                self.callbacks.append(callback)
+
+    def run_and_clear_callbacks(self) -> None:
+        """Run all callbacks and clear the stored list of callbacks. Used when
+        the node is being deleted.
+        """
+
+        if not self.callbacks:
+            return
+
+        for callback in self.callbacks:
+            callback()
+
+        self.callbacks = None
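The byte figures quoted in the comment above can be checked directly; a quick verification on 64-bit CPython (exact numbers may shift slightly between versions):

    import sys

    sys.getsizeof([])      # 56: one empty list per node adds up quickly
    sys.getsizeof(set())   # 216: the naive set-per-node approach costs ~4x more
    sys.getsizeof(None)    # 16, but None is one shared object, so storing it
                           # in a slot adds nothing per node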
 
 
 class LruCache(Generic[KT, VT]):
@@ -177,10 +259,10 @@ class LruCache(Generic[KT, VT]):
 
         self.len = synchronized(cache_len)
 
-        def add_node(key, value, callbacks: Optional[set] = None):
+        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
             prev_node = list_root
             next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value, callbacks or set())
+            node = _Node(prev_node, next_node, key, value, callbacks)
             prev_node.next_node = node
             next_node.prev_node = node
             cache[key] = node
@@ -188,6 +270,9 @@ class LruCache(Generic[KT, VT]):
             if size_callback:
                 cached_cache_len[0] += size_callback(node.value)
 
+            if caches.TRACK_MEMORY_USAGE and metrics:
+                metrics.inc_memory_usage(node.memory)
+
         def move_node_to_front(node):
             prev_node = node.prev_node
             next_node = node.next_node
@@ -211,16 +296,18 @@ class LruCache(Generic[KT, VT]):
                 deleted_len = size_callback(node.value)
                 cached_cache_len[0] -= deleted_len
 
-            for cb in node.callbacks:
-                cb()
-            node.callbacks.clear()
+            node.run_and_clear_callbacks()
+
+            if caches.TRACK_MEMORY_USAGE and metrics:
+                metrics.dec_memory_usage(node.memory)
+
             return deleted_len
 
         @overload
         def cache_get(
             key: KT,
             default: Literal[None] = None,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Optional[VT]:
             ...
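Between add_node and delete_node the memory gauge stays balanced: a node's footprint is computed once at insertion and the identical figure is subtracted on removal, so the gauge always equals the summed node.memory of the live entries. A hedged usage sketch (assumes pympler is installed, TRACK_MEMORY_USAGE is enabled, and a cache_name is passed so that metrics are registered; the name itself is illustrative):

    from synapse.util.caches.lrucache import LruCache

    cache = LruCache(max_size=2, cache_name="example_cache")
    cache.set("a", "x" * 1024)   # gauge += footprint of node "a"
    cache.set("b", "y" * 1024)   # gauge += footprint of node "b"
    cache.set("c", "z" * 1024)   # evicts LRU node "a", subtracting its footprint
    cache.pop("b")               # gauge -= footprint of node "b"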
@@ -229,7 +316,7 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: T,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Union[T, VT]:
             ...
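The overloads also tighten callbacks from Iterable to Collection. A plausible reason: the value is truth-tested in _Node.add_callbacks before being iterated, and a one-shot iterator defeats that check, as a small demonstration shows:

    gen = (cb for cb in [])   # a generator that yields nothing
    bool(gen)                 # True: generators are truthy even when empty, so an
                              # empty one would slip past the emptiness guard and
                              # allocate the per-node list the class tries to avoid
    bool(())                  # False: a real Collection reports emptiness correctly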
@@ -238,13 +325,13 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: Optional[T] = None,
-            callbacks: Iterable[Callable[[], None]] = (),
+            callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
         ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
                 return node.value
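On a hit, callers can attach invalidation callbacks through the same deduplicating path; a brief sketch (the callback is illustrative):

    from synapse.util.caches.lrucache import LruCache

    cache = LruCache(max_size=10)
    cache.set("key", "value")

    def on_invalidated() -> None:
        print("entry invalidated")

    cache.get("key", callbacks=[on_invalidated])
    cache.get("key", callbacks=[on_invalidated])  # same object again: skipped by
                                                  # the duplicate check in add_callbacks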
@@ -260,10 +347,8 @@ class LruCache(Generic[KT, VT]):
                 # We sometimes store large objects, e.g. dicts, which cause
                 # the inequality check to take a long time. So let's only do
                 # the check if we have some callbacks to call.
-                if node.callbacks and value != node.value:
-                    for cb in node.callbacks:
-                        cb()
-                    node.callbacks.clear()
+                if value != node.value:
+                    node.run_and_clear_callbacks()
 
                 # We don't bother to protect this by value != node.value as
                 # generally size_callback will be cheap compared with equality
@@ -273,7 +358,7 @@ class LruCache(Generic[KT, VT]):
                     cached_cache_len[0] -= size_callback(node.value)
                     cached_cache_len[0] += size_callback(value)
 
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
 
                 move_node_to_front(node)
                 node.value = value
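Since run_and_clear_callbacks() bails out early when nothing is registered, cache_set can simply compare values and fire callbacks only when the stored value actually changes. The observable behaviour, sketched with illustrative names:

    from synapse.util.caches.lrucache import LruCache

    cache = LruCache(max_size=10)

    def notify() -> None:
        print("value replaced")

    cache.set("k", "v1", callbacks=[notify])
    cache.set("k", "v2")   # "v1" != "v2": notify() runs, then the list is dropped
    cache.set("k", "v2")   # value unchanged: nothing fires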
@@ -326,12 +411,14 @@ class LruCache(Generic[KT, VT]):
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
-                for cb in node.callbacks:
-                    cb()
+                node.run_and_clear_callbacks()
             cache.clear()
             if size_callback:
                 cached_cache_len[0] = 0
 
+            if caches.TRACK_MEMORY_USAGE and metrics:
+                metrics.clear_memory_usage()
+
         @synchronized
         def cache_contains(key: KT) -> bool:
             return key in cache
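Finally, cache_clear tears the whole structure down in one pass: every node's callbacks are run and discarded, the length and size counters are zeroed, and, when memory tracking is on, the gauge is reset outright rather than decremented node by node. A closing sketch under the same assumptions as above:

    cache = LruCache(max_size=10)
    cache.set("key", "value", callbacks=[lambda: print("invalidated")])
    cache.clear()   # prints "invalidated", then resets length, size,
                    # and memory-usage accounting to zero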