summary refs log tree commit diff
path: root/synapse/util/caches/lrucache.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/util/caches/lrucache.py77
1 files changed, 59 insertions, 18 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a21d34fcb4..10b0ec6b75 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -17,8 +17,10 @@ from functools import wraps
 from typing import (
     Any,
     Callable,
+    Collection,
     Generic,
     Iterable,
+    List,
     Optional,
     Type,
     TypeVar,
@@ -57,13 +59,56 @@ class _Node:
     __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
 
     def __init__(
-        self, prev_node, next_node, key, value, callbacks: Optional[set] = None
+        self,
+        prev_node,
+        next_node,
+        key,
+        value,
+        callbacks: Collection[Callable[[], None]] = (),
     ):
         self.prev_node = prev_node
         self.next_node = next_node
         self.key = key
         self.value = value
-        self.callbacks = callbacks or set()
+
+        # Set of callbacks to run when the node gets deleted. We store as a list
+        # rather than a set to keep memory usage down (and since we expect few
+        # entries per node, the performance of checking for duplication in a
+        # list vs using a set is negligible).
+        #
+        # Note that we store this as an optional list to keep the memory
+        # footprint down. Storing `None` is free as its a singleton, while empty
+        # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
+        # thing and used sets).
+        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+
+        self.add_callbacks(callbacks)
+
+    def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
+        """Add to stored list of callbacks, removing duplicates."""
+
+        if not callbacks:
+            return
+
+        if not self.callbacks:
+            self.callbacks = []
+
+        for callback in callbacks:
+            if callback not in self.callbacks:
+                self.callbacks.append(callback)
+
+    def run_and_clear_callbacks(self) -> None:
+        """Run all callbacks and clear the stored list of callbacks. Used when
+        the node is being deleted.
+        """
+
+        if not self.callbacks:
+            return
+
+        for callback in self.callbacks:
+            callback()
+
+        self.callbacks = None
 
 
 class LruCache(Generic[KT, VT]):
@@ -177,10 +222,10 @@ class LruCache(Generic[KT, VT]):
 
         self.len = synchronized(cache_len)
 
-        def add_node(key, value, callbacks: Optional[set] = None):
+        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
             prev_node = list_root
             next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value, callbacks or set())
+            node = _Node(prev_node, next_node, key, value, callbacks)
             prev_node.next_node = node
             next_node.prev_node = node
             cache[key] = node
@@ -211,16 +256,15 @@ class LruCache(Generic[KT, VT]):
                 deleted_len = size_callback(node.value)
                 cached_cache_len[0] -= deleted_len
 
-            for cb in node.callbacks:
-                cb()
-            node.callbacks.clear()
+            node.run_and_clear_callbacks()
+
             return deleted_len
 
         @overload
         def cache_get(
             key: KT,
             default: Literal[None] = None,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Optional[VT]:
             ...
@@ -229,7 +273,7 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: T,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Union[T, VT]:
             ...
@@ -238,13 +282,13 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: Optional[T] = None,
-            callbacks: Iterable[Callable[[], None]] = (),
+            callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
         ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
                 return node.value
@@ -260,10 +304,8 @@ class LruCache(Generic[KT, VT]):
                 # We sometimes store large objects, e.g. dicts, which cause
                 # the inequality check to take a long time. So let's only do
                 # the check if we have some callbacks to call.
-                if node.callbacks and value != node.value:
-                    for cb in node.callbacks:
-                        cb()
-                    node.callbacks.clear()
+                if value != node.value:
+                    node.run_and_clear_callbacks()
 
                 # We don't bother to protect this by value != node.value as
                 # generally size_callback will be cheap compared with equality
@@ -273,7 +315,7 @@ class LruCache(Generic[KT, VT]):
                     cached_cache_len[0] -= size_callback(node.value)
                     cached_cache_len[0] += size_callback(value)
 
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
 
                 move_node_to_front(node)
                 node.value = value
@@ -326,8 +368,7 @@ class LruCache(Generic[KT, VT]):
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
-                for cb in node.callbacks:
-                    cb()
+                node.run_and_clear_callbacks()
             cache.clear()
             if size_callback:
                 cached_cache_len[0] = 0