summary refs log tree commit diff
path: root/synapse/util/caches/treecache.py
blob: 460e98a92d601d9fbe8ed7435a1e3f5a4cfe6b10 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Unique marker meaning "nothing found". It is compared by identity, so it
# cannot be confused with any real cached value (including None).
SENTINEL = object()


class TreeCache(object):
    """
    Tree-based backing store for LruCache. Allows subtrees of data to be deleted
    efficiently.
    Keys must be tuples.

    Each element of a key selects one level of nested dicts below ``root``;
    the value lives at the final level, wrapped in an ``_Entry``. ``size``
    counts the stored entries and is what ``len()`` reports.
    """
    def __init__(self):
        self.size = 0
        self.root = {}

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``; same as ``set(key, value)``."""
        return self.set(key, value)

    def __contains__(self, key):
        """Return True if a value is stored under ``key``."""
        return self.get(key, SENTINEL) is not SENTINEL

    def set(self, key, value):
        """Store ``value`` under ``key``, creating intermediate levels as needed.

        Overwriting an existing entry replaces its value without changing
        the reported size.
        """
        node = self.root
        for k in key[:-1]:
            node = node.setdefault(k, {})

        leaf_key = key[-1]
        # Only grow the count when the slot did not already hold an entry,
        # otherwise setting the same key repeatedly would inflate len().
        replacing = isinstance(node.get(leaf_key), _Entry)
        node[leaf_key] = _Entry(value)
        if not replacing:
            self.size += 1

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        node = self.root
        for k in key[:-1]:
            node = node.get(k, None)
            if node is None:
                return default

        # Look up with a sentinel rather than building a throwaway _Entry
        # around ``default`` on every call.
        entry = node.get(key[-1], SENTINEL)
        if entry is SENTINEL:
            return default
        return entry.value

    def clear(self):
        """Drop every entry and reset the size to zero."""
        self.size = 0
        self.root = {}

    def pop(self, key, default=None):
        """Remove ``key`` from the tree and return what was stored there.

        If ``key`` addresses a single entry its value is returned. If it
        addresses a subtree, the whole subtree is removed and returned as a
        dict with the entries replaced by their values. ``default`` is
        returned when nothing is stored at ``key``. Intermediate levels left
        empty by the removal are pruned.
        """
        # Intermediate nodes visited on the way down, excluding the root;
        # nodes[i] is the node reached from its parent via key[i].
        nodes = []

        node = self.root
        for k in key[:-1]:
            node = node.get(k, None)
            nodes.append(node)  # don't add the root node
            if node is None:
                return default
        popped = node.pop(key[-1], SENTINEL)
        if popped is SENTINEL:
            return default

        # Walk back up from the deepest node, detaching each node that is
        # now empty from its parent. Stop at the first non-empty node, since
        # everything above it still has content. Built from plain sequences
        # so it works whether zip() returns a list or an iterator.
        parents = [self.root] + nodes[:-1]
        path = zip(reversed(parents), reversed(key[:-1]), reversed(nodes))
        for parent, k, child in path:
            if child:
                break
            parent.pop(k)

        popped, cnt = _strip_and_count_entires(popped)
        self.size -= cnt
        return popped

    def values(self):
        """Return a list of every value stored in the tree."""
        return list(popped_to_iterator(self.root))

    def __len__(self):
        """Return the number of stored entries."""
        return self.size


def popped_to_iterator(d):
    """Yield every value contained in ``d``, flattening nested dicts.

    ``d`` may be a (possibly nested) dict, an ``_Entry`` or a bare value.
    Dicts are descended into recursively, an ``_Entry`` yields the value it
    wraps, and anything else is yielded unchanged.
    """
    if isinstance(d, dict):
        # ``values()`` exists on both Python 2 and Python 3, whereas
        # ``itervalues()`` is Python 2 only.
        for value_d in d.values():
            for value in popped_to_iterator(value_d):
                yield value
    elif isinstance(d, _Entry):
        yield d.value
    else:
        yield d


class _Entry(object):
    """Wrapper marking a stored value as a leaf of the tree.

    Because leaves are wrapped, code walking the tree can tell a stored
    value apart from an interior dict node, even if the value is a dict.
    """
    # Only ``value`` is ever stored, so no per-instance __dict__ is needed.
    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value


def _strip_and_count_entires(d):
    """Unwrap the _Entry leaves of ``d`` and count them.

    ``d`` is either a single _Entry or a dict whose leaves are _Entry's.
    For an _Entry the wrapped value is returned; for a dict the same dict
    object is returned after every _Entry in it (at any depth) has been
    replaced in place by its value.

    Returns a ``(stripped, count)`` tuple, where ``count`` is the number of
    _Entry's that were unwrapped.
    """
    if not isinstance(d, dict):
        # A leaf: hand back the wrapped value and count it once.
        return d.value, 1

    total = 0
    # Iterate over a snapshot of the keys since the dict is rewritten as we go.
    for child_key in list(d):
        stripped, child_count = _strip_and_count_entires(d[child_key])
        d[child_key] = stripped
        total += child_count
    return d, total