summary refs log tree commit diff
path: root/synapse/storage/state.py
blob: a2df8fa8272af3d26cbc53d4373a2a182ec4bb02 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from six import iteritems, itervalues

import attr

from synapse.api.constants import EventTypes

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


@attr.s(slots=True)
class StateFilter(object):
    """A filter used when querying for state.

    Attributes:
        types (dict[str, set[str]|None]): Map from type to set of state keys (or
            None). This specifies which state_keys for the given type to fetch
            from the DB. If None then all events with that type are fetched. If
            the set is empty then no events with that type are fetched.
        include_others (bool): Whether to fetch events with types that do not
            appear in `types`.
    """

    types = attr.ib()
    include_others = attr.ib(default=False)

    def __attrs_post_init__(self):
        # If `include_others` is set we canonicalise the filter by removing
        # wildcards from the types dictionary: a wildcard entry for a type is
        # equivalent to that type not being listed at all, since unlisted
        # types are fetched in full when `include_others` is True.
        if self.include_others:
            self.types = {k: v for k, v in iteritems(self.types) if v is not None}

    @staticmethod
    def all():
        """Creates a filter that fetches everything.

        Returns:
            StateFilter
        """
        return StateFilter(types={}, include_others=True)

    @staticmethod
    def none():
        """Creates a filter that fetches nothing.

        Returns:
            StateFilter
        """
        return StateFilter(types={}, include_others=False)

    @staticmethod
    def from_types(types):
        """Creates a filter that only fetches the given types

        Args:
            types (Iterable[tuple[str, str|None]]): A list of type and state
                keys to fetch. A state_key of None fetches everything for
                that type

        Returns:
            StateFilter
        """
        type_dict = {}
        for typ, s in types:
            if typ in type_dict and type_dict[typ] is None:
                # We already have a wildcard for this type, which subsumes
                # any specific state key.
                continue

            if s is None:
                # A wildcard replaces any state keys collected so far.
                type_dict[typ] = None
                continue

            type_dict.setdefault(typ, set()).add(s)

        return StateFilter(types=type_dict)

    @staticmethod
    def from_lazy_load_member_list(members):
        """Creates a filter that returns all non-member events, plus the member
        events for the given users

        Args:
            members (iterable[str]): Set of user IDs

        Returns:
            StateFilter
        """
        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)

    def return_expanded(self):
        """Creates a new StateFilter where type wild cards have been removed
        (except for memberships). The returned filter is a superset of the
        current one, i.e. anything that passes the current filter will pass
        the returned filter.

        This helps the caching as the DictionaryCache knows if it has *all* the
        state, but does not know if it has all of the keys of a particular type,
        which makes wildcard lookups expensive unless we have a complete cache.
        Hence, if we are doing a wildcard lookup, populate the cache fully so
        that we can do an efficient lookup next time.

        Note that since we have two caches, one for membership events and one for
        other events, we can be a bit more clever than simply returning
        `StateFilter.all()` if `has_wildcards()` is True.

        We return a StateFilter where:
            1. the list of membership events to return is the same
            2. if there is a wildcard that matches non-member events we
               return all non-member events

        Returns:
            StateFilter
        """

        if self.is_full():
            # If we're going to return everything then there's nothing to do
            return self

        if not self.has_wildcards():
            # If there are no wild cards, there's nothing to do
            return self

        if EventTypes.Member in self.types:
            get_all_members = self.types[EventTypes.Member] is None
        else:
            get_all_members = self.include_others

        has_non_member_wildcard = self.include_others or any(
            state_keys is None
            for t, state_keys in iteritems(self.types)
            if t != EventTypes.Member
        )

        if not has_non_member_wildcard:
            # If there are no non-member wild cards we can just return ourselves
            return self

        if get_all_members:
            # We want to return everything.
            return StateFilter.all()

        # We want to return all non-members, but only particular memberships.
        # If the filter has no membership entry at all (and doesn't include
        # other types) then no memberships are wanted, so we fall back to an
        # empty set rather than raising a KeyError.
        return StateFilter(
            types={EventTypes.Member: self.types.get(EventTypes.Member, set())},
            include_others=True,
        )

    def make_sql_filter_clause(self):
        """Converts the filter to an SQL clause.

        For example:

            f = StateFilter.from_types([("m.room.create", "")])
            clause, args = f.make_sql_filter_clause()
            clause == "(type = ? AND state_key = ?)"
            args == ['m.room.create', '']


        Returns:
            tuple[str, list]: The SQL string (may be empty) and arguments. An
            empty SQL string is returned when the filter matches everything
            (i.e. is "full").
        """

        where_clause = ""
        where_args = []

        if self.is_full():
            return where_clause, where_args

        # First we build up a list of clauses for each type/state_key combo
        clauses = []
        for etype, state_keys in iteritems(self.types):
            if state_keys is None:
                clauses.append("(type = ?)")
                where_args.append(etype)
                continue

            for state_key in state_keys:
                clauses.append("(type = ? AND state_key = ?)")
                where_args.extend((etype, state_key))

        if not clauses and not self.include_others:
            # i.e. this filter can't match anything (either `types` is empty
            # or every type maps to an empty set of state keys), so we need
            # to return a clause that will match nothing. Returning an empty
            # clause here would incorrectly be treated as "match everything".
            return "1 = 2", []

        # This will match anything that appears in `self.types`
        where_clause = " OR ".join(clauses)

        # If we want to include stuff that's not in the types dict then we add
        # a `OR type NOT IN (...)` clause to the end.
        if self.include_others:
            if where_clause:
                where_clause += " OR "

            where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
            where_args.extend(self.types)

        return where_clause, where_args

    def max_entries_returned(self):
        """Returns the maximum number of entries this filter will return if
        known, otherwise returns None.

        For example a simple state filter asking for `("m.room.create", "")`
        will return 1, whereas the default state filter will return None.

        This is used to bail out early if the right number of entries have been
        fetched.

        Returns:
            int|None
        """
        if self.has_wildcards():
            return None

        return len(self.concrete_types())

    def filter_state(self, state_dict):
        """Returns the state filtered by this StateFilter

        Args:
            state_dict (dict[tuple[str, str], Any]): The state map to filter

        Returns:
            dict[tuple[str, str], Any]: The filtered state map (always a new
            dict, the input is not modified)
        """
        if self.is_full():
            return dict(state_dict)

        filtered_state = {}
        for k, v in iteritems(state_dict):
            typ, state_key = k
            if typ in self.types:
                state_keys = self.types[typ]
                if state_keys is None or state_key in state_keys:
                    filtered_state[k] = v
            elif self.include_others:
                filtered_state[k] = v

        return filtered_state

    def is_full(self):
        """Whether this filter fetches everything or not

        Returns:
            bool
        """
        return self.include_others and not self.types

    def has_wildcards(self):
        """Whether the filter includes wildcards or is attempting to fetch
        specific state.

        Returns:
            bool
        """

        return self.include_others or any(
            state_keys is None for state_keys in itervalues(self.types)
        )

    def concrete_types(self):
        """Returns a list of concrete type/state_keys (i.e. not None) that
        will be fetched. This will be a complete list if `has_wildcards`
        returns False, but otherwise will be a subset (or even empty).

        Returns:
            list[tuple[str,str]]
        """
        return [
            (t, s)
            for t, state_keys in iteritems(self.types)
            if state_keys is not None
            for s in state_keys
        ]

    def get_member_split(self):
        """Return the filter split into two: one which assumes it's exclusively
        matching against member state, and one which assumes it's matching
        against non member state.

        This is useful due to the returned filters giving correct results for
        `is_full()`, `has_wildcards()`, etc, when operating against maps that
        either exclusively contain member events or only contain non-member
        events. (Which is the case when dealing with the member vs non-member
        state caches).

        Returns:
            tuple[StateFilter, StateFilter]: The member and non member filters
        """

        if EventTypes.Member in self.types:
            state_keys = self.types[EventTypes.Member]
            if state_keys is None:
                member_filter = StateFilter.all()
            else:
                member_filter = StateFilter({EventTypes.Member: state_keys})
        elif self.include_others:
            member_filter = StateFilter.all()
        else:
            member_filter = StateFilter.none()

        non_member_filter = StateFilter(
            types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
            include_others=self.include_others,
        )

        return member_filter, non_member_filter