summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/_base.py11
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/cache.py70
-rw-r--r--synapse/util/caches/lrucache.py237
-rw-r--r--synapse/util/linked_list.py150
5 files changed, 404 insertions, 66 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 8879136881..b30571fe49 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -21,7 +21,7 @@ import socket
 import sys
 import traceback
 import warnings
-from typing import Awaitable, Callable, Iterable
+from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
 
 from cryptography.utils import CryptographyDeprecationWarning
 from typing_extensions import NoReturn
@@ -41,10 +41,14 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
+from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
 from synapse.util.daemonize import daemonize_process
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # list of tuples of function, args list, kwargs dict
@@ -312,7 +316,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")
 
 
-async def start(hs: "synapse.server.HomeServer"):
+async def start(hs: "HomeServer"):
     """
     Start a Synapse server or worker.
 
@@ -365,6 +369,9 @@ async def start(hs: "synapse.server.HomeServer"):
 
     load_legacy_spam_checkers(hs)
 
+    # If we've configured an expiry time for caches, start the background job now.
+    setup_expire_lru_cache_entries(hs)
+
     # It is now safe to start your Synapse.
     hs.start_listening()
     hs.get_datastore().db_pool.start_profiling()
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 23ca0c83c1..06fbd1166b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -5,6 +5,7 @@ from synapse.config import (
     api,
     appservice,
     auth,
+    cache,
     captcha,
     cas,
     consent,
@@ -88,6 +89,7 @@ class RootConfig:
     tracer: tracer.TracerConfig
     redis: redis.RedisConfig
     modules: modules.ModulesConfig
+    caches: cache.CacheConfig
     federation: federation.FederationConfig
 
     config_classes: List = ...
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 91165ee1ce..7789b40323 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -116,35 +116,41 @@ class CacheConfig(Config):
         #event_cache_size: 10K
 
         caches:
-           # Controls the global cache factor, which is the default cache factor
-           # for all caches if a specific factor for that cache is not otherwise
-           # set.
-           #
-           # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
-           # variable. Setting by environment variable takes priority over
-           # setting through the config file.
-           #
-           # Defaults to 0.5, which will half the size of all caches.
-           #
-           #global_factor: 1.0
-
-           # A dictionary of cache name to cache factor for that individual
-           # cache. Overrides the global cache factor for a given cache.
-           #
-           # These can also be set through environment variables comprised
-           # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
-           # letters and underscores. Setting by environment variable
-           # takes priority over setting through the config file.
-           # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
-           #
-           # Some caches have '*' and other characters that are not
-           # alphanumeric or underscores. These caches can be named with or
-           # without the special characters stripped. For example, to specify
-           # the cache factor for `*stateGroupCache*` via an environment
-           # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
-           #
-           per_cache_factors:
-             #get_users_who_share_room_with_user: 2.0
+          # Controls the global cache factor, which is the default cache factor
+          # for all caches if a specific factor for that cache is not otherwise
+          # set.
+          #
+          # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
+          # variable. Setting by environment variable takes priority over
+          # setting through the config file.
+          #
+          # Defaults to 0.5, which will half the size of all caches.
+          #
+          #global_factor: 1.0
+
+          # A dictionary of cache name to cache factor for that individual
+          # cache. Overrides the global cache factor for a given cache.
+          #
+          # These can also be set through environment variables comprised
+          # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
+          # letters and underscores. Setting by environment variable
+          # takes priority over setting through the config file.
+          # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
+          #
+          # Some caches have '*' and other characters that are not
+          # alphanumeric or underscores. These caches can be named with or
+          # without the special characters stripped. For example, to specify
+          # the cache factor for `*stateGroupCache*` via an environment
+          # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
+          #
+          per_cache_factors:
+            #get_users_who_share_room_with_user: 2.0
+
+          # Controls how long an entry can be in a cache without having been
+          # accessed before being evicted. Defaults to None, which means
+          # entries are never evicted based on time.
+          #
+          #expiry_time: 30m
         """
 
     def read_config(self, config, **kwargs):
@@ -200,6 +206,12 @@ class CacheConfig(Config):
                     e.message  # noqa: B306, DependencyException.message is a property
                 )
 
+        expiry_time = cache_config.get("expiry_time")
+        if expiry_time:
+            self.expiry_time_msec = self.parse_duration(expiry_time)
+        else:
+            self.expiry_time_msec = None
+
         # Resize all caches (if necessary) with the new factors we've loaded
         self.resize_all_caches()
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index d89e9d9b1d..4b9d0433ff 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import threading
+import weakref
 from functools import wraps
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -31,10 +34,19 @@ from typing import (
 
 from typing_extensions import Literal
 
+from twisted.internet import reactor
+
 from synapse.config import cache as cache_config
-from synapse.util import caches
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.linked_list import ListNode
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
 
 try:
     from pympler.asizeof import Asizer
@@ -82,19 +94,126 @@ def enumerate_leaves(node, depth):
                 yield m
 
 
+P = TypeVar("P")
+
+
+class _TimedListNode(ListNode[P]):
+    """A `ListNode` that tracks last access time."""
+
+    __slots__ = ["last_access_ts_secs"]
+
+    def update_last_access(self, clock: Clock):
+        self.last_access_ts_secs = int(clock.time())
+
+
+# Whether to insert new cache entries to the global list. We only add to it if
+# time based eviction is enabled.
+USE_GLOBAL_LIST = False
+
+# A linked list of all cache entries, allowing efficient time based eviction.
+GLOBAL_ROOT = ListNode["_Node"].create_root_node()
+
+
+@wrap_as_background_process("LruCache._expire_old_entries")
+async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+    """Walks the global cache list to find cache entries that haven't been
+    accessed in the given number of seconds.
+    """
+
+    now = int(clock.time())
+    node = GLOBAL_ROOT.prev_node
+    assert node is not None
+
+    i = 0
+
+    logger.debug("Searching for stale caches")
+
+    while node is not GLOBAL_ROOT:
+        # Only the root node isn't a `_TimedListNode`.
+        assert isinstance(node, _TimedListNode)
+
+        if node.last_access_ts_secs > now - expiry_seconds:
+            break
+
+        cache_entry = node.get_cache_entry()
+        next_node = node.prev_node
+
+        # The node should always have a reference to a cache entry and a valid
+        # `prev_node`, as we only drop them when we remove the node from the
+        # list.
+        assert next_node is not None
+        assert cache_entry is not None
+        cache_entry.drop_from_cache()
+
+        # If we do lots of work at once we yield to allow other stuff to happen.
+        if (i + 1) % 10000 == 0:
+            logger.debug("Waiting during drop")
+            await clock.sleep(0)
+            logger.debug("Waking during drop")
+
+        node = next_node
+
+        # If we've yielded then our current node may have been evicted, so we
+        # need to check that its still valid.
+        if node.prev_node is None:
+            break
+
+        i += 1
+
+    logger.info("Dropped %d items from caches", i)
+
+
+def setup_expire_lru_cache_entries(hs: "HomeServer"):
+    """Start a background job that expires all cache entries if they have not
+    been accessed for the given number of seconds.
+    """
+    if not hs.config.caches.expiry_time_msec:
+        return
+
+    logger.info(
+        "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
+    )
+
+    global USE_GLOBAL_LIST
+    USE_GLOBAL_LIST = True
+
+    clock = hs.get_clock()
+    clock.looping_call(
+        _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+    )
+
+
 class _Node:
-    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
+    __slots__ = [
+        "_list_node",
+        "_global_list_node",
+        "_cache",
+        "key",
+        "value",
+        "callbacks",
+        "memory",
+    ]
 
     def __init__(
         self,
-        prev_node,
-        next_node,
+        root: "ListNode[_Node]",
         key,
         value,
+        cache: "weakref.ReferenceType[LruCache]",
+        clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
     ):
-        self.prev_node = prev_node
-        self.next_node = next_node
+        self._list_node = ListNode.insert_after(self, root)
+        self._global_list_node = None
+        if USE_GLOBAL_LIST:
+            self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
+            self._global_list_node.update_last_access(clock)
+
+        # We store a weak reference to the cache object so that this _Node can
+        # remove itself from the cache. If the cache is dropped we ensure we
+        # remove our entries in the lists.
+        self._cache = cache
+
         self.key = key
         self.value = value
 
@@ -116,11 +235,16 @@ class _Node:
             self.memory = (
                 _get_size_of(key)
                 + _get_size_of(value)
+                + _get_size_of(self._list_node, recurse=False)
                 + _get_size_of(self.callbacks, recurse=False)
                 + _get_size_of(self, recurse=False)
             )
             self.memory += _get_size_of(self.memory, recurse=False)
 
+            if self._global_list_node:
+                self.memory += _get_size_of(self._global_list_node, recurse=False)
+                self.memory += _get_size_of(self._global_list_node.last_access_ts_secs)
+
     def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
         """Add to stored list of callbacks, removing duplicates."""
 
@@ -147,6 +271,32 @@ class _Node:
 
         self.callbacks = None
 
+    def drop_from_cache(self) -> None:
+        """Drop this node from the cache.
+
+        Ensures that the entry gets removed from the cache and that we get
+        removed from all lists.
+        """
+        cache = self._cache()
+        if not cache or not cache.pop(self.key, None):
+            # `cache.pop` should call `drop_from_lists()`, unless this Node had
+            # already been removed from the cache.
+            self.drop_from_lists()
+
+    def drop_from_lists(self) -> None:
+        """Remove this node from the cache lists."""
+        self._list_node.remove_from_list()
+
+        if self._global_list_node:
+            self._global_list_node.remove_from_list()
+
+    def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
+        """Moves this node to the front of all the lists its in."""
+        self._list_node.move_after(cache_list_root)
+        if self._global_list_node:
+            self._global_list_node.move_after(GLOBAL_ROOT)
+            self._global_list_node.update_last_access(clock)
+
 
 class LruCache(Generic[KT, VT]):
     """
@@ -163,6 +313,7 @@ class LruCache(Generic[KT, VT]):
         size_callback: Optional[Callable] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
+        clock: Optional[Clock] = None,
     ):
         """
         Args:
@@ -188,6 +339,13 @@ class LruCache(Generic[KT, VT]):
             apply_cache_factor_from_config (bool): If true, `max_size` will be
                 multiplied by a cache factor derived from the homeserver config
         """
+        # Default `clock` to something sensible. Note that we rename it to
+        # `real_clock` so that mypy doesn't think its still `Optional`.
+        if clock is None:
+            real_clock = Clock(reactor)
+        else:
+            real_clock = clock
+
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -219,17 +377,31 @@ class LruCache(Generic[KT, VT]):
         # this is exposed for access from outside this class
         self.metrics = metrics
 
-        list_root = _Node(None, None, None, None)
-        list_root.next_node = list_root
-        list_root.prev_node = list_root
+        # We create a single weakref to self here so that we don't need to keep
+        # creating more each time we create a `_Node`.
+        weak_ref_to_self = weakref.ref(self)
+
+        list_root = ListNode[_Node].create_root_node()
 
         lock = threading.Lock()
 
         def evict():
             while cache_len() > self.max_size:
+                # Get the last node in the list (i.e. the oldest node).
                 todelete = list_root.prev_node
-                evicted_len = delete_node(todelete)
-                cache.pop(todelete.key, None)
+
+                # The list root should always have a valid `prev_node` if the
+                # cache is not empty.
+                assert todelete is not None
+
+                # The node should always have a reference to a cache entry, as
+                # we only drop the cache entry when we remove the node from the
+                # list.
+                node = todelete.get_cache_entry()
+                assert node is not None
+
+                evicted_len = delete_node(node)
+                cache.pop(node.key, None)
                 if metrics:
                     metrics.inc_evictions(evicted_len)
 
@@ -255,11 +427,7 @@ class LruCache(Generic[KT, VT]):
         self.len = synchronized(cache_len)
 
         def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
-            prev_node = list_root
-            next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value, callbacks)
-            prev_node.next_node = node
-            next_node.prev_node = node
+            node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
             cache[key] = node
 
             if size_callback:
@@ -268,23 +436,11 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node):
-            prev_node = node.prev_node
-            next_node = node.next_node
-            prev_node.next_node = next_node
-            next_node.prev_node = prev_node
-            prev_node = list_root
-            next_node = prev_node.next_node
-            node.prev_node = prev_node
-            node.next_node = next_node
-            prev_node.next_node = node
-            next_node.prev_node = node
-
-        def delete_node(node):
-            prev_node = node.prev_node
-            next_node = node.next_node
-            prev_node.next_node = next_node
-            next_node.prev_node = prev_node
+        def move_node_to_front(node: _Node):
+            node.move_to_front(real_clock, list_root)
+
+        def delete_node(node: _Node) -> int:
+            node.drop_from_lists()
 
             deleted_len = 1
             if size_callback:
@@ -411,10 +567,13 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_clear() -> None:
-            list_root.next_node = list_root
-            list_root.prev_node = list_root
             for node in cache.values():
                 node.run_and_clear_callbacks()
+                node.drop_from_lists()
+
+            assert list_root.next_node == list_root
+            assert list_root.prev_node == list_root
+
             cache.clear()
             if size_callback:
                 cached_cache_len[0] = 0
@@ -484,3 +643,11 @@ class LruCache(Generic[KT, VT]):
                 self._on_resize()
             return True
         return False
+
+    def __del__(self) -> None:
+        # We're about to be deleted, so we make sure to clear up all the nodes
+        # and run callbacks, etc.
+        #
+        # This happens e.g. in the sync code where we have an expiring cache of
+        # lru caches.
+        self.clear()
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
new file mode 100644
index 0000000000..a456b136f0
--- /dev/null
+++ b/synapse/util/linked_list.py
@@ -0,0 +1,150 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A circular doubly linked list implementation.
+"""
+
+import threading
+from typing import Generic, Optional, Type, TypeVar
+
+P = TypeVar("P")
+LN = TypeVar("LN", bound="ListNode")
+
+
+class ListNode(Generic[P]):
+    """A node in a circular doubly linked list, with an (optional) reference to
+    a cache entry.
+
+    The reference should only be `None` for the root node or if the node has
+    been removed from the list.
+    """
+
+    # A lock to protect mutating the list prev/next pointers.
+    _LOCK = threading.Lock()
+
+    # We don't use attrs here as in py3.6 you can't have `attr.s(slots=True)`
+    # and inherit from `Generic` for some reason
+    __slots__ = [
+        "cache_entry",
+        "prev_node",
+        "next_node",
+    ]
+
+    def __init__(self, cache_entry: Optional[P] = None) -> None:
+        self.cache_entry = cache_entry
+        self.prev_node: Optional[ListNode[P]] = None
+        self.next_node: Optional[ListNode[P]] = None
+
+    @classmethod
+    def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]":
+        """Create a new linked list by creating a "root" node, which is a node
+        that has prev_node/next_node pointing to itself and no associated cache
+        entry.
+        """
+        root = cls()
+        root.prev_node = root
+        root.next_node = root
+        return root
+
+    @classmethod
+    def insert_after(
+        cls: Type[LN],
+        cache_entry: P,
+        node: "ListNode[P]",
+    ) -> LN:
+        """Create a new list node that is placed after the given node.
+
+        Args:
+            cache_entry: The associated cache entry.
+            node: The existing node in the list to insert the new entry after.
+        """
+        new_node = cls(cache_entry)
+        with cls._LOCK:
+            new_node._refs_insert_after(node)
+        return new_node
+
+    def remove_from_list(self):
+        """Remove this node from the list."""
+        with self._LOCK:
+            self._refs_remove_node_from_list()
+
+        # We drop the reference to the cache entry to break the reference cycle
+        # between the list node and cache entry, allowing the two to be dropped
+        # immediately rather than at the next GC.
+        self.cache_entry = None
+
+    def move_after(self, node: "ListNode"):
+        """Move this node from its current location in the list to after the
+        given node.
+        """
+        with self._LOCK:
+            # We assert that both this node and the target node is still "alive".
+            assert self.prev_node
+            assert self.next_node
+            assert node.prev_node
+            assert node.next_node
+
+            assert self is not node
+
+            # Remove self from the list
+            self._refs_remove_node_from_list()
+
+            # Insert self back into the list, after target node
+            self._refs_insert_after(node)
+
+    def _refs_remove_node_from_list(self):
+        """Internal method to *just* remove the node from the list, without
+        e.g. clearing out the cache entry.
+        """
+        if self.prev_node is None or self.next_node is None:
+            # We've already been removed from the list.
+            return
+
+        prev_node = self.prev_node
+        next_node = self.next_node
+
+        prev_node.next_node = next_node
+        next_node.prev_node = prev_node
+
+        # We set these to None so that we don't get circular references,
+        # allowing us to be dropped without having to go via the GC.
+        self.prev_node = None
+        self.next_node = None
+
+    def _refs_insert_after(self, node: "ListNode"):
+        """Internal method to insert the node after the given node."""
+
+        # This method should only be called when we're not already in the list.
+        assert self.prev_node is None
+        assert self.next_node is None
+
+        # We expect the given node to be in the list and thus have valid
+        # prev/next refs.
+        assert node.next_node
+        assert node.prev_node
+
+        prev_node = node
+        next_node = node.next_node
+
+        self.prev_node = prev_node
+        self.next_node = next_node
+
+        prev_node.next_node = self
+        next_node.prev_node = self
+
+    def get_cache_entry(self) -> Optional[P]:
+        """Get the cache entry, returns None if this is the root node (i.e.
+        cache_entry is None) or if the entry has been dropped.
+        """
+        return self.cache_entry