summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrewm@element.io>2022-09-22 15:54:30 +0100
committerAndrew Morgan <andrewm@element.io>2022-09-22 15:54:30 +0100
commitf1d98d3b708a4b05b33d1974d0d527bd32fff211 (patch)
treebffabb35f9c97fac6e8fb58d7565735e1f824947
parentwip (diff)
downloadsynapse-f1d98d3b708a4b05b33d1974d0d527bd32fff211.tar.xz
wip2
-rw-r--r--synapse/util/caches/dual_lookup_cache.py134
-rw-r--r--synapse/util/caches/lrucache.py27
2 files changed, 96 insertions, 65 deletions
diff --git a/synapse/util/caches/dual_lookup_cache.py b/synapse/util/caches/dual_lookup_cache.py
index 7529a5b8f9..6ee72de705 100644
--- a/synapse/util/caches/dual_lookup_cache.py
+++ b/synapse/util/caches/dual_lookup_cache.py
@@ -17,12 +17,14 @@ from typing import (
     Dict,
     Generic,
     ItemsView,
+    List,
     Optional,
-    Set,
     TypeVar,
+    Union,
     ValuesView,
 )
 
+# Used to discern between a value not existing in a map, or the value being 'None'.
 SENTINEL = object()
 
 # The type of the primary dict's keys.
@@ -35,6 +37,13 @@ SKT = TypeVar("SKT")
 logger = logging.getLogger(__name__)
 
 
+class SecondarySet(set):
+    """
+    Used to differentiate between an entry in the secondary_dict, and a set stored
+    in the primary_dict. This is necessary as pop() can return either.
+    """
+
+
 class DualLookupCache(Generic[PKT, PVT, SKT]):
     """
     A backing store for LruCache that supports multiple entry points.
@@ -79,7 +88,7 @@ class DualLookupCache(Generic[PKT, PVT, SKT]):
 
     def __init__(self, secondary_key_function: Callable[[PVT], SKT]) -> None:
         self._primary_dict: Dict[PKT, PVT] = {}
-        self._secondary_dict: Dict[SKT, Set[PKT]] = {}
+        self._secondary_dict: Dict[SKT, SecondarySet] = {}
         self._secondary_key_function = secondary_key_function
 
     def __setitem__(self, key: PKT, value: PVT) -> None:
@@ -108,7 +117,9 @@ class DualLookupCache(Generic[PKT, PVT, SKT]):
 
         # And create a mapping in the secondary_dict to a set containing the
         # primary_key, creating the set if necessary.
-        secondary_key_set = self._secondary_dict.setdefault(secondary_key, set())
+        secondary_key_set = self._secondary_dict.setdefault(
+            secondary_key, SecondarySet()
+        )
         secondary_key_set.add(key)
 
         logger.info("*** Insert into primary_dict: %s: %s", key, value)
@@ -138,69 +149,84 @@ class DualLookupCache(Generic[PKT, PVT, SKT]):
         self._primary_dict.clear()
         self._secondary_dict.clear()
 
-    def pop(self, key: PKT, default: Optional[PVT] = None) -> Optional[PVT]:
-        """Remove the given key, from the cache if it exists, and return the associated
-        value.
+    def pop(
+        self, key: Union[PKT, SKT], default: Optional[Union[Dict[PKT, PVT], PVT]] = None
+    ) -> Optional[Union[Dict[PKT, PVT], PVT]]:
+        """Remove an entry from either the primary_dict or secondary_dict.
 
-        Evicts an entry from both the primary_dict and secondary_dict.
+        The primary_dict is checked first for the key. If an entry is found, it is
+        removed from the primary_dict and returned.
+
+        If no entry in the primary_dict exists, then the secondary_dict is checked.
+        If an entry exists, all associated entries in the primary_dict will be
+        deleted, and all primary_dict keys returned from this function in a SecondarySet.
 
         Args:
-            key: The key to remove from the cache.
-            default: The value to return if the given key is not found.
+            key: A key to drop from either the primary_dict or secondary_dict.
+            default: The default value if the key does not exist in either dict.
 
         Returns:
-            The value associated with the given key if it is found. Otherwise, the value
-            of `default`.
+            Either a matched value from the primary_dict or the secondary_dict. If no
+            value is found for the key, then None.
         """
-        # Exit immediately if the key is not found
-        if key not in self._primary_dict:
-            return default
-
-        # Pop the entry from the primary_dict to retrieve the desired value
-        primary_value = self._primary_dict.pop(key)
-
-        logger.info("*** Popping from primary_dict: %s: %s", key, primary_value)
-
-        # Derive the secondary_key from the primary_value
-        secondary_key = self._secondary_key_function(primary_value)
-
-        # Pop the entry from the secondary_dict
-        secondary_key_set = self._secondary_dict[secondary_key]
-        if len(secondary_key_set) > 1:
-            # Delete just the set entry for the given key.
-            secondary_key_set.remove(key)
-            logger.info("*** Popping from secondary_dict: %s: %s", secondary_key, key)
-
-        else:
-            # Delete the entire soon-to-be-empty set referenced by the secondary_key.
-            del self._secondary_dict[secondary_key]
-            logger.info("*** Popping from secondary_dict: %s", secondary_key)
+        # Attempt to remove from the primary_dict first.
+        primary_value = self._primary_dict.pop(key, SENTINEL)
+        if primary_value is not SENTINEL:
+            # We found a value in the primary_dict. Remove it from the corresponding
+            # entry in the secondary_dict, and then return it.
+            logger.info(
+                "*** Popped entry from primary_dict: %s: %s", key, primary_value
+            )
 
-        return primary_value
+            # Derive the secondary_key from the primary_value
+            secondary_key = self._secondary_key_function(primary_value)
+
+            # Pop the entry from the secondary_dict
+            secondary_key_set = self._secondary_dict[secondary_key]
+            if len(secondary_key_set) > 1:
+                # Delete just the set entry for the given key.
+                secondary_key_set.remove(key)
+                logger.info(
+                    "*** Popping from secondary_dict: %s: %s", secondary_key, key
+                )
+            else:
+                # Delete the entire set referenced by the secondary_key, as it only
+                # has one entry.
+                del self._secondary_dict[secondary_key]
+                logger.info("*** Popping from secondary_dict: %s", secondary_key)
+
+            return primary_value
+
+        # There was no matching value in the primary_dict. Attempt the secondary_dict.
+        primary_key_set = self._secondary_dict.pop(key, SENTINEL)
+        if primary_key_set is not SENTINEL:
+            # We found a set in the secondary_dict.
+            logger.info(
+                "*** Found '%s' in secondary_dict: %s: ",
+                key,
+                primary_key_set,
+            )
 
-    def del_multi(self, secondary_key: SKT) -> None:
-        """Remove an entry from the secondary_dict, removing all associated entries
-        in the primary_dict as well.
+            popped_primary_dict_values: List[PVT] = []
 
-        Args:
-            secondary_key: A secondary_key to drop. May be associated with zero or more
-                primary keys. If any associated primary keys are found, they will be
-                dropped as well.
-        """
-        primary_key_set = self._secondary_dict.pop(secondary_key, None)
-        if not primary_key_set:
+            # We found an entry in the secondary_dict. Delete all related entries in the
+            # primary_dict.
             logger.info(
-                "*** Did not find '%s' in secondary_dict: %s",
-                secondary_key,
-                self._secondary_dict,
+                "*** Found key in secondary_dict to pop: %s. "
+                "Popping primary_dict entries",
+                key,
             )
-            return
+            for primary_key in primary_key_set:
+                logger.info("*** Popping entry from primary_dict: %s", primary_key)
+                logger.info("*** primary_dict: %s", self._primary_dict)
+                popped_primary_dict_values = self._primary_dict[primary_key]
+                del self._primary_dict[primary_key]
+
+            # Now return the unmodified copy of the set.
+            return popped_primary_dict_values
 
-        logger.info("*** Popping whole key from secondary_dict: %s", secondary_key)
-        for primary_key in primary_key_set:
-            logger.info("*** Popping entry from primary_dict: %s", primary_key)
-            logger.info("*** primary_dict: %s", self._primary_dict)
-            del self._primary_dict[primary_key]
+        # No match in either dict.
+        return default
 
     def values(self) -> ValuesView:
         return self._primary_dict.values()
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 30765e630d..6345735680 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -46,9 +46,10 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.dual_lookup_cache import DualLookupCache
+from synapse.util.caches.dual_lookup_cache import DualLookupCache, SecondarySet
 from synapse.util.caches.treecache import (
     TreeCache,
+    TreeCacheNode,
     iterate_tree_cache_entry,
     iterate_tree_cache_items,
 )
@@ -751,21 +752,25 @@ class LruCache(Generic[KT, VT]):
             may be of lower cardinality than the TreeCache - in which case the whole
             subtree is deleted.
             """
-            if isinstance(cache, DualLookupCache):
-                # Make use of DualLookupCache's del_multi feature
-                cache.del_multi(key)
-                return
-
             # Remove an entry from the cache.
             # In the case of a 'dict' cache type, we're just removing an entry from the
             # dict. For a TreeCache, we're removing a subtree which has children.
-            popped_entry = cache.pop(key, None)
-            if popped_entry is not None and cache_type is TreeCache:
-                # We've popped a subtree - now we need to clean up each child node.
-                # For each deleted node, we remove it from the linked list and run
-                # its callbacks.
+            popped_entry: _Node[KT, VT] = cache.pop(key, None)
+            if popped_entry is None:
+                return
+
+            if isinstance(popped_entry, TreeCacheNode):
+                # We've popped a subtree from a TreeCache - now we need to clean up
+                # each child node.
                 for leaf in iterate_tree_cache_entry(popped_entry):
+                    # For each deleted child node, we remove it from the linked list and
+                    # run its callbacks.
+                    delete_node(leaf)
+            elif isinstance(popped_entry, SecondarySet):
+                for leaf in popped_entry:
                     delete_node(leaf)
+            else:
+                delete_node(popped_entry)
 
         @synchronized
         def cache_clear() -> None: