summary refs log tree commit diff
path: root/synapse/storage/_base.py
blob: fe4a76341137a1df6c3a4f2e2418ae66c7fee011 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union

from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
from synapse.util.caches.descriptors import CachedFunction

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


# some of our subclasses have abstract methods, so we use the ABCMeta metaclass.
class SQLBaseStore(metaclass=ABCMeta):
    """Base class for data stores that holds helper functions.

    Note that multiple instances of this class will exist as there will be one
    per data store (and not one per physical database).

    Besides holding references to the homeserver and the database pool, this
    class provides the helpers used to invalidate the local, in-memory caches
    that depend on the current state of a room.
    """

    db_pool: DatabasePool

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        """
        Args:
            database: The database pool this store runs its queries against.
            db_conn: A database connection. It is not used by this base class
                itself; it is part of the common constructor signature.
            hs: The homeserver this store belongs to.
        """
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine
        self.db_pool = database

        # Caches registered via `register_external_cached_function`, keyed by
        # cache name. Consulted by `_attempt_to_invalidate_cache` when the
        # store itself has no attribute with the requested name.
        self.external_cached_functions: Dict[str, CachedFunction] = {}

    def process_replication_rows(  # noqa: B027 (no-op by design)
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        """
        Used by storage classes to invalidate caches based on incoming replication data. These
        must not update any ID generators, use `process_replication_position`.

        The base implementation does nothing; subclasses override it as needed.
        """

    def process_replication_position(  # noqa: B027 (no-op by design)
        self,
        stream_name: str,
        instance_name: str,
        token: int,
    ) -> None:
        """
        Used by storage classes to advance ID generators based on incoming replication data. This
        is called after process_replication_rows such that caches are invalidated before any token
        positions advance.

        The base implementation does nothing; subclasses override it as needed.
        """

    def _invalidate_state_caches(
        self, room_id: str, members_changed: Collection[str]
    ) -> None:
        """Invalidates caches that are based on the current state, but does
        not stream invalidations down replication.

        Args:
            room_id: Room where state changed
            members_changed: The user_ids of members that have changed
        """

        # XXX: If you add something to this function make sure you add it to
        # `_invalidate_state_caches_all` as well.

        # If there were any membership changes, purge the appropriate caches.
        # The hosts are de-duplicated so that each (room, host) entry is only
        # invalidated once, however many of that host's users changed.
        for host in {get_domain_from_id(u) for u in members_changed}:
            self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
            self._attempt_to_invalidate_cache("is_host_invited", (room_id, host))
        if members_changed:
            self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
            self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
            self._attempt_to_invalidate_cache(
                "get_users_in_room_with_profiles", (room_id,)
            )
            self._attempt_to_invalidate_cache(
                "get_number_joined_users_in_room", (room_id,)
            )
            self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))

            # There's no easy way of invalidating this cache for just the users
            # that have changed, so we just clear the entire thing.
            self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)

        for user_id in members_changed:
            self._attempt_to_invalidate_cache(
                "get_user_in_room_with_profile", (room_id, user_id)
            )
            self._attempt_to_invalidate_cache(
                "get_rooms_for_user_with_stream_ordering", (user_id,)
            )
            self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))

        # Purge other caches based on room state. These are invalidated even
        # when no memberships changed.
        self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
        self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))

    def _invalidate_state_caches_all(self, room_id: str) -> None:
        """Invalidates caches that are based on the current state, but does
        not stream invalidations down replication.

        Same as `_invalidate_state_caches`, except that works when we don't know
        which memberships have changed.

        Caches keyed only on the room are invalidated for `room_id`; caches
        whose keys include a host or a user are cleared entirely, since the
        affected hosts and users are unknown.

        Args:
            room_id: Room where state changed
        """
        self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
        self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
        self._attempt_to_invalidate_cache("is_host_invited", None)
        self._attempt_to_invalidate_cache("is_host_joined", None)
        self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
        self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,))
        self._attempt_to_invalidate_cache("get_number_joined_users_in_room", (room_id,))
        self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
        self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
        self._attempt_to_invalidate_cache("get_user_in_room_with_profile", None)
        self._attempt_to_invalidate_cache(
            "get_rooms_for_user_with_stream_ordering", None
        )
        self._attempt_to_invalidate_cache("get_rooms_for_user", None)
        self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

    def _attempt_to_invalidate_cache(
        self, cache_name: str, key: Optional[Collection[Any]]
    ) -> bool:
        """Attempts to invalidate the cache of the given name, ignoring if the
        cache doesn't exist. Mainly used for invalidating caches on workers,
        where they may not have the cache.

        Note that this function does not invalidate any remote caches, only the
        local in-memory ones. Any remote invalidation must be performed before
        calling this.

        Args:
            cache_name: Name of the cache. It is looked up first as an
                attribute of this store, then among the externally registered
                cached functions.
            key: Entry to invalidate. If None then invalidates the entire
                cache.

        Returns:
            True if a cache with that name was found and invalidated, False if
            no such cache exists here.
        """

        try:
            cache = getattr(self, cache_name)
        except AttributeError:
            # Check if an externally defined module cache has been registered
            cache = self.external_cached_functions.get(cache_name)
            # Compare against None rather than relying on truthiness: a
            # registered cache object may legitimately evaluate as falsy
            # (e.g. if it defines `__len__`) and must still be invalidated.
            if cache is None:
                # We probably haven't pulled in the cache in this worker,
                # which is fine.
                return False

        if key is None:
            cache.invalidate_all()
        else:
            # Prefer any local-only invalidation method. Invalidating any non-local
            # cache must be done before this.
            invalidate_method = getattr(cache, "invalidate_local", None)
            if invalidate_method is None:
                # Only look up the general method when it is actually needed,
                # so caches exposing just `invalidate_local` work too.
                invalidate_method = cache.invalidate
            # The caches are keyed on tuples, so normalise any other
            # collection type (e.g. a list) before invalidating.
            invalidate_method(tuple(key))

        return True

    def register_external_cached_function(
        self, cache_name: str, func: CachedFunction
    ) -> None:
        """Registers a cached function defined outside this store so that
        `_attempt_to_invalidate_cache` can find it by name.

        Registering the same name again replaces the previous entry.

        Args:
            cache_name: The name the cache will be looked up under.
            func: The cached function to invalidate for that name.
        """
        self.external_cached_functions[cache_name] = func


def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
    """
    Take some data from a database row and return a JSON-decoded object.

    Args:
        db_content: The JSON-encoded contents from the database.

    Returns:
        The object decoded from JSON.

    Raises:
        Exception: Any error raised by the JSON decoder is logged at warning
            level and then re-raised unchanged. Errors raised while decoding
            binary content as UTF-8 happen before that and are not logged here.
    """
    # psycopg2 on Python 3 returns memoryview objects, which we need to
    # cast to bytes to decode
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()

    # Decode it to a Unicode string before feeding it to the JSON decoder, since
    # it only supports handling strings
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")

    try:
        return json_decoder.decode(db_content)
    except Exception:
        # Log through this module's logger rather than the root-logger helper
        # `logging.warning`, so the record carries the module name and no
        # implicit root logger configuration is triggered.
        logger.warning("Tried to decode '%r' as JSON and failed", db_content)
        raise