summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrewm@element.io>2022-09-21 17:37:25 +0100
committerAndrew Morgan <andrewm@element.io>2022-09-21 17:37:38 +0100
commit6ff8ba5fc64f6588bf4a6b5031fc0d0603524b73 (patch)
treeedda34ff4e8c7aa6259abd3e1b57e29ed320f149
parentAdd a `MXCUri` class to make working with mxc uri's easier. (#13162) (diff)
downloadsynapse-6ff8ba5fc64f6588bf4a6b5031fc0d0603524b73.tar.xz
wip
-rw-r--r--synapse/storage/databases/main/censor_events.py2
-rw-r--r--synapse/storage/databases/main/events.py8
-rw-r--r--synapse/storage/databases/main/events_worker.py59
-rw-r--r--synapse/storage/databases/main/purge_events.py3
-rw-r--r--synapse/util/caches/dual_lookup_cache.py212
-rw-r--r--synapse/util/caches/lrucache.py53
-rw-r--r--tests/storage/test_purge.py1
7 files changed, 318 insertions, 20 deletions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec1..b5397475ab 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             # changed its content in the database. We can't call
             # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
             # right type.
-            self.invalidate_get_event_cache_after_txn(txn, event.event_id)
+            self.invalidate_get_event_cache_by_event_id_after_txn(txn, event.event_id)
             # Send that invalidation to replication so that other workers also invalidate
             # the event cache.
             self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1b54a2eb57..ddc1797ec7 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1294,8 +1294,10 @@ class PersistEventsStore:
         """
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
-            # Remove the any existing cache entries for the event_ids
-            self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
+            # Remove any existing cache entries for the event_ids
+            self.store.invalidate_get_event_cache_by_event_id_after_txn(
+                txn, event.event_id
+            )
             # Then update the `stream_ordering` position to mark the latest
             # event as the front of the room. This should not be done for
             # backfilled events because backfilled events have negative
@@ -1703,7 +1705,7 @@ class PersistEventsStore:
         _invalidate_caches_for_event.
         """
         assert event.redacts is not None
-        self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
+        self.store.invalidate_get_event_cache_by_event_id_after_txn(txn, event.redacts)
         txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
         txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 52914febf9..293c94946d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -80,6 +80,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.dual_lookup_cache import DualLookupCache
 from synapse.util.caches.lrucache import AsyncLruCache
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
@@ -245,6 +246,8 @@ class EventsWorkerStore(SQLBaseStore):
         ] = AsyncLruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
+            cache_type=DualLookupCache,
+            dual_lookup_secondary_key_function=lambda v: (v.event.room_id,),
         )
 
         # Map from event ID to a deferred that will result in a map from event
@@ -733,7 +736,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def invalidate_get_event_cache_after_txn(
+    def invalidate_get_event_cache_by_event_id_after_txn(
         self, txn: LoggingTransaction, event_id: str
     ) -> None:
         """
@@ -747,10 +750,31 @@ class EventsWorkerStore(SQLBaseStore):
             event_id: the event ID to be invalidated from caches
         """
 
-        txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
-        txn.call_after(self._invalidate_local_get_event_cache, event_id)
+        txn.async_call_after(
+            self._invalidate_async_get_event_cache_by_event_id, event_id
+        )
+        txn.call_after(self._invalidate_local_get_event_cache_by_event_id, event_id)
 
-    async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+    def invalidate_get_event_cache_by_room_id_after_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> None:
+        """
+        Prepares a database transaction to invalidate the get event cache for a given
+        room ID when executed successfully. This is achieved by attaching two callbacks
+        to the transaction, one to invalidate the async cache and one for the in memory
+        sync cache (importantly called in that order).
+
+        Arguments:
+            txn: the database transaction to attach the callbacks to.
+            room_id: the room ID to invalidate all associated event caches for.
+        """
+
+        txn.async_call_after(self._invalidate_async_get_event_cache_by_room_id, room_id)
+        txn.call_after(self._invalidate_local_get_event_cache_by_room_id, room_id)
+
+    async def _invalidate_async_get_event_cache_by_event_id(
+        self, event_id: str
+    ) -> None:
         """
         Invalidates an event in the asyncronous get event cache, which may be remote.
 
@@ -760,7 +784,18 @@ class EventsWorkerStore(SQLBaseStore):
 
         await self._get_event_cache.invalidate((event_id,))
 
-    def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+    async def _invalidate_async_get_event_cache_by_room_id(self, room_id: str) -> None:
+        """
+        Invalidates all events associated with a given room in the asyncronous get event
+        cache, which may be remote.
+
+        Arguments:
+            room_id: the room ID to invalidate associated events of.
+        """
+
+        await self._get_event_cache.invalidate((room_id,))
+
+    def _invalidate_local_get_event_cache_by_event_id(self, event_id: str) -> None:
         """
         Invalidates an event in local in-memory get event caches.
 
@@ -772,6 +807,18 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
 
+    def _invalidate_local_get_event_cache_by_room_id(self, room_id: str) -> None:
+        """
+        Invalidates all events associated with a given room ID in local in-memory
+        get event caches.
+
+        Arguments:
+            room_id: the room ID to invalidate events of.
+        """
+        self._get_event_cache.invalidate_local((room_id,))
+
+        # TODO: invalidate _event_ref and _current_event_fetches. How?
+
     async def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
@@ -2284,7 +2331,7 @@ class EventsWorkerStore(SQLBaseStore):
             updatevalues={"rejection_reason": rejection_reason},
         )
 
-        self.invalidate_get_event_cache_after_txn(txn, event_id)
+        self.invalidate_get_event_cache_by_event_id_after_txn(txn, event_id)
 
         # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
         #   call '_send_invalidation_to_replication', but we actually need the other
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index f6822707e4..0a6e7ec5ce 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 self._invalidate_cache_and_stream(
                     txn, self.have_seen_event, (room_id, event_id)
                 )
-                self.invalidate_get_event_cache_after_txn(txn, event_id)
+                self.invalidate_get_event_cache_by_event_id_after_txn(txn, event_id)
 
         logger.info("[purge] done")
 
@@ -478,6 +478,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         # XXX: as with purge_history, this is racy, but no worse than other races
         #   that already exist.
         self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
+        self._invalidate_local_get_event_cache_by_room_id(room_id)
 
         logger.info("[purge] done")
 
diff --git a/synapse/util/caches/dual_lookup_cache.py b/synapse/util/caches/dual_lookup_cache.py
new file mode 100644
index 0000000000..7529a5b8f9
--- /dev/null
+++ b/synapse/util/caches/dual_lookup_cache.py
@@ -0,0 +1,212 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    ItemsView,
+    Optional,
+    Set,
+    TypeVar,
+    ValuesView,
+)
+
+SENTINEL = object()
+
+# The type of the primary dict's keys.
+PKT = TypeVar("PKT")
+# The type of the primary dict's values.
+PVT = TypeVar("PVT")
+# The type of the secondary dict's keys.
+SKT = TypeVar("SKT")
+
+logger = logging.getLogger(__name__)
+
+
+class DualLookupCache(Generic[PKT, PVT, SKT]):
+    """
+    A backing store for LruCache that supports multiple entry points.
+    Allows subsets of data to be deleted efficiently without requiring extra
+    information to query.
+
+    The data structure is two dictionaries:
+        * primary_dict containing a mapping of primary_key -> value.
+        * secondary_dict containing a mapping of secondary_key -> set of primary_key.
+
+    On insert, a mapping in the primary_dict must be created. A mapping in the
+    secondary_dict from a secondary_key to (a set containing) the same
+    primary_key will be made. The secondary_key
+    must be derived from the inserted value via a lambda function provided at cache
+    initialisation. This is so invalidated entries in the primary_dict may automatically
+    invalidate those in the secondary_dict. The secondary_key may be associated with one
+    or more primary_key's.
+
+    This creates an interface which allows for efficient lookups of a value given
+    a primary_key, as well as efficient invalidation of a subset of mapping in the
+    primary_dict given a secondary_key. A primary_key may not be associated with more
+    than one secondary_key.
+
+    As a worked example, consider storing a cache of room events. We could configure
+    the cache to store mappings between EventIDs and EventBase in the primary_dict,
+    while storing a mapping between room IDs and event IDs as the secondary_dict:
+
+        primary_dict: EventID -> EventBase
+        secondary_dict: RoomID -> {EventID, EventID, ...}
+
+    This would be efficient for the following operations:
+        * Given an EventID, look up the associated EventBase, and thus the roomID.
+        * Given a RoomID, invalidate all primary_dict entries for events in that room.
+
+    Since this is intended as a backing store for LRUCache, when it came time to evict
+    an entry from the primary_dict (EventID -> EventBase), the secondary_key could be
+    derived from a provided lambda function:
+        secondary_key = lambda event_base: event_base.room_id
+
+    The EventID set under room_id would then have the appropriate EventID entry evicted.
+    """
+
+    def __init__(self, secondary_key_function: Callable[[PVT], SKT]) -> None:
+        self._primary_dict: Dict[PKT, PVT] = {}
+        self._secondary_dict: Dict[SKT, Set[PKT]] = {}
+        self._secondary_key_function = secondary_key_function
+
+    def __setitem__(self, key: PKT, value: PVT) -> None:
+        self.set(key, value)
+
+    def __contains__(self, key: PKT) -> bool:
+        return key in self._primary_dict
+
+    def set(self, key: PKT, value: PVT) -> None:
+        """Add an entry to the cache.
+
+        Will add an entry to the primary_dict consisting of key->value, as well as append
+        to the set referred to by secondary_key_function(value) in the secondary_dict.
+
+        Args:
+            key: The key for a new mapping in primary_dict.
+            value: The value for a new mapping in primary_dict.
+        """
+        # Create an entry in the primary_dict.
+        self._primary_dict[key] = value
+
+        # Derive the secondary_key to use from the given primary_value.
+        secondary_key = self._secondary_key_function(value)
+
+        # TODO: If the lambda function resolves to None, don't insert an entry?
+
+        # And create a mapping in the secondary_dict to a set containing the
+        # primary_key, creating the set if necessary.
+        secondary_key_set = self._secondary_dict.setdefault(secondary_key, set())
+        secondary_key_set.add(key)
+
+        logger.info("*** Insert into primary_dict: %s: %s", key, value)
+        logger.info("*** Insert into secondary_dict: %s: %s", secondary_key, key)
+
+    def get(self, key: PKT, default: Optional[PVT] = None) -> Optional[PVT]:
+        """Retrieve a value from the cache if it exists. If not, return the default
+        value.
+
+        This method simply pulls entries from the primary_dict.
+
+        # TODO: Any use cases for externally getting entries from the secondary_dict?
+
+        Args:
+            key: The key to search the cache for.
+            default: The default value to return if the given key is not found.
+
+        Returns:
+            The value referenced by the given key, if it exists in the cache. If not,
+            the value of `default` will be returned.
+        """
+        logger.info("*** Retrieving key from primary_dict: %s", key)
+        return self._primary_dict.get(key, default)
+
+    def clear(self) -> None:
+        """Evicts all entries from the cache."""
+        self._primary_dict.clear()
+        self._secondary_dict.clear()
+
+    def pop(self, key: PKT, default: Optional[PVT] = None) -> Optional[PVT]:
+        """Remove the given key, from the cache if it exists, and return the associated
+        value.
+
+        Evicts an entry from both the primary_dict and secondary_dict.
+
+        Args:
+            key: The key to remove from the cache.
+            default: The value to return if the given key is not found.
+
+        Returns:
+            The value associated with the given key if it is found. Otherwise, the value
+            of `default`.
+        """
+        # Exit immediately if the key is not found
+        if key not in self._primary_dict:
+            return default
+
+        # Pop the entry from the primary_dict to retrieve the desired value
+        primary_value = self._primary_dict.pop(key)
+
+        logger.info("*** Popping from primary_dict: %s: %s", key, primary_value)
+
+        # Derive the secondary_key from the primary_value
+        secondary_key = self._secondary_key_function(primary_value)
+
+        # Pop the entry from the secondary_dict
+        secondary_key_set = self._secondary_dict[secondary_key]
+        if len(secondary_key_set) > 1:
+            # Delete just the set entry for the given key.
+            secondary_key_set.remove(key)
+            logger.info("*** Popping from secondary_dict: %s: %s", secondary_key, key)
+
+        else:
+            # Delete the entire soon-to-be-empty set referenced by the secondary_key.
+            del self._secondary_dict[secondary_key]
+            logger.info("*** Popping from secondary_dict: %s", secondary_key)
+
+        return primary_value
+
+    def del_multi(self, secondary_key: SKT) -> None:
+        """Remove an entry from the secondary_dict, removing all associated entries
+        in the primary_dict as well.
+
+        Args:
+            secondary_key: A secondary_key to drop. May be associated with zero or more
+                primary keys. If any associated primary keys are found, they will be
+                dropped as well.
+        """
+        primary_key_set = self._secondary_dict.pop(secondary_key, None)
+        if not primary_key_set:
+            logger.info(
+                "*** Did not find '%s' in secondary_dict: %s",
+                secondary_key,
+                self._secondary_dict,
+            )
+            return
+
+        logger.info("*** Popping whole key from secondary_dict: %s", secondary_key)
+        for primary_key in primary_key_set:
+            logger.info("*** Popping entry from primary_dict: %s", primary_key)
+            logger.info("*** primary_dict: %s", self._primary_dict)
+            del self._primary_dict[primary_key]
+
+    def values(self) -> ValuesView:
+        return self._primary_dict.values()
+
+    def items(self) -> ItemsView:
+        return self._primary_dict.items()
+
+    def __len__(self) -> int:
+        return len(self._primary_dict)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index aa93109d13..30765e630d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -46,6 +46,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
+from synapse.util.caches.dual_lookup_cache import DualLookupCache
 from synapse.util.caches.treecache import (
     TreeCache,
     iterate_tree_cache_entry,
@@ -375,12 +376,13 @@ class LruCache(Generic[KT, VT]):
         self,
         max_size: int,
         cache_name: Optional[str] = None,
-        cache_type: Type[Union[dict, TreeCache]] = dict,
+        cache_type: Type[Union[dict, TreeCache, DualLookupCache]] = dict,
         size_callback: Optional[Callable[[VT], int]] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
         prune_unread_entries: bool = True,
+        dual_lookup_secondary_key_function: Optional[Callable[[Any], Any]] = None,
     ):
         """
         Args:
@@ -411,6 +413,10 @@ class LruCache(Generic[KT, VT]):
             prune_unread_entries: If True, cache entries that haven't been read recently
                 will be evicted from the cache in the background. Set to False to
                 opt-out of this behaviour.
+
+            # TODO: At this point we should probably just pass an initialised cache type
+            # to LruCache, no?
+            dual_lookup_secondary_key_function:
         """
         # Default `clock` to something sensible. Note that we rename it to
         # `real_clock` so that mypy doesn't think its still `Optional`.
@@ -419,7 +425,30 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock
 
-        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
+        # TODO: I've had to make this ugly to appease mypy :(
+        # Perhaps initialise the backing cache and then pass to LruCache?
+        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache, DualLookupCache]
+        if cache_type is DualLookupCache:
+            # The dual_lookup_secondary_key_function is a function that's intended to
+            # extract a key from the value in the cache. Since we wrap values given to
+            # us in a _Node object, this function will actually operate on a _Node,
+            # instead of directly on the object type callers are expecting.
+            #
+            # Thus, we wrap the function given by the caller in another one that
+            # extracts the value from the _Node, before then handing it off to the
+            # given function for processing.
+            def key_function_wrapper(node: Any) -> Any:
+                assert dual_lookup_secondary_key_function is not None
+                return dual_lookup_secondary_key_function(node.value)
+
+            cache = DualLookupCache(
+                secondary_key_function=key_function_wrapper,
+            )
+        elif cache_type is TreeCache:
+            cache = TreeCache()
+        else:
+            cache = {}
+
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
@@ -722,13 +751,21 @@ class LruCache(Generic[KT, VT]):
             may be of lower cardinality than the TreeCache - in which case the whole
             subtree is deleted.
             """
-            popped = cache.pop(key, None)
-            if popped is None:
+            if isinstance(cache, DualLookupCache):
+                # Make use of DualLookupCache's del_multi feature
+                cache.del_multi(key)
                 return
-            # for each deleted node, we now need to remove it from the linked list
-            # and run its callbacks.
-            for leaf in iterate_tree_cache_entry(popped):
-                delete_node(leaf)
+
+            # Remove an entry from the cache.
+            # In the case of a 'dict' cache type, we're just removing an entry from the
+            # dict. For a TreeCache, we're removing a subtree which has children.
+            popped_entry = cache.pop(key, None)
+            if popped_entry is not None and cache_type is TreeCache:
+                # We've popped a subtree - now we need to clean up each child node.
+                # For each deleted node, we remove it from the linked list and run
+                # its callbacks.
+                for leaf in iterate_tree_cache_entry(popped_entry):
+                    delete_node(leaf)
 
         @synchronized
         def cache_clear() -> None:
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 9c1182ed16..4782cac4fc 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -115,6 +115,5 @@ class PurgeTests(HomeserverTestCase):
         )
 
         # The events aren't found.
-        self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)