summary refs log tree commit diff
path: root/synapse/util/linked_list.py
blob: e9a5fff2118b7d539a0723a703e538a62857300b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

"""A circular doubly linked list implementation.
"""

import threading
from typing import Generic, Optional, Type, TypeVar

# Type of the cache entry carried by a node.
P = TypeVar("P")
# A `ListNode` (sub)class; lets `insert_after` return the class it was called on.
LN = TypeVar("LN", bound="ListNode")


class ListNode(Generic[P]):
    """One link of a circular doubly linked list, optionally holding a cache
    entry.

    The cache entry is expected to be `None` only for the list's root node, or
    once the node has been taken out of the list.
    """

    # Shared by every node: guards all rewiring of the prev/next pointers.
    _LOCK = threading.Lock()

    # attrs is deliberately avoided here: on py3.6 `attr.s(slots=True)` cannot
    # be combined with inheriting from `Generic`.
    __slots__ = ["cache_entry", "prev_node", "next_node"]

    def __init__(self, cache_entry: Optional[P] = None) -> None:
        # A freshly constructed node is detached, so both links start unset.
        self.prev_node: Optional[ListNode[P]] = None
        self.next_node: Optional[ListNode[P]] = None
        self.cache_entry = cache_entry

    @classmethod
    def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]":
        """Start a new, empty list.

        Returns:
            The list's "root": a node without a cache entry whose
            prev_node/next_node both refer back to itself.
        """
        root = cls()
        root.prev_node = root.next_node = root
        return root

    @classmethod
    def insert_after(
        cls: Type[LN],
        cache_entry: P,
        node: "ListNode[P]",
    ) -> LN:
        """Build a node for `cache_entry` and link it in directly after `node`.

        Args:
            cache_entry: The entry the new node will hold.
            node: A node already in the list; the new node becomes its
                successor.

        Returns:
            The newly created (and linked) node.
        """
        created = cls(cache_entry)
        with cls._LOCK:
            created._refs_insert_after(node)
        return created

    def remove_from_list(self) -> None:
        """Unlink this node from its list and let go of its cache entry."""
        with self._LOCK:
            self._refs_remove_node_from_list()

        # The node and its cache entry reference one another; clearing our side
        # breaks that cycle so both can be freed straight away instead of
        # waiting for the next GC.
        self.cache_entry = None

    def move_after(self, node: "ListNode[P]") -> None:
        """Relocate this node so that it directly follows `node`.

        Both nodes must currently be linked into a list, and must be distinct.
        """
        with self._LOCK:
            # Neither this node nor the target may have been removed already.
            assert self.prev_node and self.next_node
            assert node.prev_node and node.next_node

            assert self is not node

            # Unlink from the current position, then relink behind the target.
            self._refs_remove_node_from_list()
            self._refs_insert_after(node)

    def _refs_remove_node_from_list(self) -> None:
        """Internal: unlink this node from its neighbours and nothing more
        (the cache entry, for instance, is left untouched).
        """
        before, after = self.prev_node, self.next_node
        if before is None or after is None:
            # Already detached; nothing to do.
            return

        # Join the two neighbours directly to each other.
        before.next_node = after
        after.prev_node = before

        # Clearing our own links avoids circular references, so this node can
        # be freed without going via the GC.
        self.prev_node = self.next_node = None

    def _refs_insert_after(self, node: "ListNode[P]") -> None:
        """Internal: link this (currently detached) node in right after `node`."""

        # Only valid while we are not part of any list.
        assert self.prev_node is None and self.next_node is None

        # The anchor must itself be linked, i.e. have valid prev/next refs.
        successor = node.next_node
        assert successor
        assert node.prev_node

        self.prev_node, self.next_node = node, successor

        node.next_node = self
        successor.prev_node = self

    def get_cache_entry(self) -> Optional[P]:
        """Return the cache entry held by this node.

        The result is `None` for the root node (which has no entry) and for a
        node whose entry has been dropped.
        """
        return self.cache_entry