summary refs log tree commit diff
path: root/synapse/util/caches/treecache.py
blob: b3bc0493006ace26ae32152d727d0ec24c3e6bc9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

# A unique marker object that can never compare identical to a value a caller
# stores (including None). It is used as the "not found" default for dict
# lookups in this module, so that stored None values are distinguishable from
# missing entries.
SENTINEL = object()


class TreeCacheNode(dict):
    """An interior node of a TreeCache.

    This is an ordinary ``dict`` subclass that adds no behaviour of its own. The
    distinct type exists only so that code walking the tree can tell an interior
    node apart from a leaf value, even when that leaf value is itself a plain
    ``dict``.
    """


class TreeCache:
    """
    Tree-based backing store for LruCache. Allows subtrees of data to be deleted
    efficiently.
    Keys must be tuples.

    The data structure is a chain of TreeCacheNodes:
        root = {key_1: {key_2: _value}}

    Every key stored in one TreeCache is expected to have the same length. A
    position in the tree holds either a value or a subtree, never both; `set`
    raises ValueError on attempts to mix the two. `size` counts the leaf values
    currently stored.
    """

    def __init__(self) -> None:
        # number of leaf values currently stored in the tree
        self.size: int = 0
        self.root = TreeCacheNode()

    def __setitem__(self, key, value) -> None:
        self.set(key, value)

    def __contains__(self, key) -> bool:
        # SENTINEL (rather than None) so that a stored None still counts as present
        return self.get(key, SENTINEL) is not SENTINEL

    def set(self, key, value) -> None:
        """Store `value` under the full key `key`.

        Replacing the value of an existing key does not change `size`.

        Args:
            key: tuple of key parts.
            value: the value to store. It must not be a TreeCacheNode.

        Raises:
            ValueError: if `value` is a TreeCacheNode, or if `key` conflicts with
                the existing shape of the tree. A conflict means either that a
                prefix of `key` already holds a value, or that `key` itself
                already roots a subtree.
        """
        if isinstance(value, TreeCacheNode):
            # this would mean we couldn't tell where our tree ended and the value
            # started.
            raise ValueError("Cannot store TreeCacheNodes in a TreeCache")

        node = self.root
        for k in key[:-1]:
            next_node = node.get(k, SENTINEL)
            if next_node is SENTINEL:
                next_node = node[k] = TreeCacheNode()
            elif not isinstance(next_node, TreeCacheNode):
                # this suggests that the caller is not being consistent with its key
                # length.
                raise ValueError("value conflicts with an existing subtree")
            node = next_node

        existing = node.get(key[-1], SENTINEL)
        if isinstance(existing, TreeCacheNode):
            # `key` is a strict prefix of keys already in the cache. Overwriting
            # here would silently drop that whole subtree and corrupt `size`.
            raise ValueError("value conflicts with an existing subtree")

        node[key[-1]] = value
        if existing is SENTINEL:
            # Only genuinely new entries grow the cache; replacing an existing
            # value must not inflate the size.
            self.size += 1

    def get(self, key, default=None):
        """When `key` is a full key, fetches the value for the given key (if
        any).

        If `key` is only a partial key (i.e. a truncated tuple) then returns a
        `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*`
        functions to iterate over all entries in the cache with keys that start
        with the given partial key.

        Returns `default` if nothing is stored under `key`. This includes the
        case where `key` is longer than the stored keys; in that case the lookup
        never looks inside stored values, even if they are dicts.
        """
        node = self.root
        for k in key[:-1]:
            node = node.get(k, SENTINEL)
            if not isinstance(node, TreeCacheNode):
                # Either nothing is stored along this path, or we hit a leaf
                # value before consuming the whole key. Leaves are opaque: never
                # index into them.
                return default
        return node.get(key[-1], default)

    def clear(self) -> None:
        """Remove every entry from the cache."""
        self.size = 0
        self.root = TreeCacheNode()

    def pop(self, key, default=None):
        """Remove the given key, or subkey, from the cache

        Args:
            key: key or subkey to remove.
            default: value to return if key is not found

        Returns:
            If the key is not found, 'default'. If the key is complete, the removed
            value. If the key is partial, the TreeCacheNode corresponding to the part
            of the tree that was removed.

        Raises:
            TypeError: if `key` is not a tuple.
            ValueError: if `key` runs past a stored value (key too long).
        """
        if not isinstance(key, tuple):
            raise TypeError("The cache key must be a tuple not %r" % (type(key),))

        # the non-root nodes we pass through on the way down, in order of depth
        nodes = []

        node = self.root
        for k in key[:-1]:
            # SENTINEL rather than None, so that a stored None leaf is recognised
            # as a leaf (and reported as "key too long") rather than as "missing".
            node = node.get(k, SENTINEL)
            if node is SENTINEL:
                return default
            if not isinstance(node, TreeCacheNode):
                # we've gone off the end of the tree
                raise ValueError("pop() key too long")
            nodes.append(node)

        popped = node.pop(key[-1], SENTINEL)
        if popped is SENTINEL:
            return default

        # Working back up the tree, remove any nodes that the pop left empty.
        # nodes[i] is stored in parents[i] under key[i].
        parents = [self.root, *nodes[:-1]]
        for parent, child, k in reversed(list(zip(parents, nodes, key))):
            if child:
                # a non-empty node keeps all of its ancestors alive
                break
            # found an empty node: remove it from its parent, and keep climbing
            parent.pop(k)

        # `popped` may be a whole subtree, so count every leaf it held
        self.size -= sum(1 for _ in iterate_tree_cache_entry(popped))
        return popped

    def values(self):
        """Return a generator over every stored value."""
        return iterate_tree_cache_entry(self.root)

    def items(self):
        """Return a generator of (full_key, value) pairs for every stored value."""
        return iterate_tree_cache_items((), self.root)

    def __len__(self) -> int:
        return self.size


def iterate_tree_cache_entry(d):
    """Yield every leaf value stored under `d`, depth-first in insertion order.

    If `d` is a TreeCacheNode, every value beneath it is yielded. Anything else,
    including a plain dict, is treated as a single leaf and yielded unchanged.
    """
    if not isinstance(d, TreeCacheNode):
        # a leaf is its own (and only) entry
        yield d
        return

    for child in d.values():
        yield from iterate_tree_cache_entry(child)


def iterate_tree_cache_items(key, value):
    """Yield ``(full_key, leaf_value)`` pairs for every leaf under `value`.

    `key` is the sequence of key parts that leads to `value`; it becomes the
    prefix of every key yielded. If `value` is a TreeCacheNode, the tree beneath
    it is walked depth-first in insertion order, and the prefix grows by one part
    per level. Otherwise `value` is itself a leaf, and ``(key, value)`` is
    yielded once.

    Example:

        cache = TreeCache()
        cache[(1, 1)] = "a"
        cache[(1, 2)] = "b"
        cache[(2, 1)] = "c"

        tree_node = cache.get((1,))

        items = iterate_tree_cache_items((1,), tree_node)
        assert list(items) == [((1, 1), "a"), ((1, 2), "b")]

    Returns:
        A generator yielding key/value pairs.
    """
    if not isinstance(value, TreeCacheNode):
        # reached a leaf: the accumulated prefix is its full key
        yield key, value
        return

    for child_key, child in value.items():
        yield from iterate_tree_cache_items((*key, child_key), child)