diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 597 |
1 files changed, 343 insertions, 254 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5673e4aa96..89a05c4618 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,16 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches import intern_string -from synapse.util.stringutils import to_ascii -from synapse.storage.engines import PostgresEngine +import logging +from collections import namedtuple + +from six import iteritems, itervalues +from six.moves import range from twisted.internet import defer -from collections import namedtuple -import logging +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii + +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -40,45 +46,19 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateStore(SQLBaseStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. +class StateGroupWorkerStore(SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, hs): - super(StateStore, self).__init__(hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", + def __init__(self, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(db_conn, hs) + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") ) @cached(max_entries=100000, iterable=True) @@ -158,12 +138,26 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups) defer.returnValue(group_to_state) @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the state IDs for the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + defer.returnValue(group_to_state[state_group]) + + @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids @@ -176,199 +170,27 @@ class StateStore(SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in group_to_ids.itervalues() - for ev_id in group_ids.itervalues() + ev_id for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) ], get_prev_content=False ) defer.returnValue({ group: [ - state_event_map[v] for v in event_id_map.itervalues() + state_event_map[v] for v in itervalues(event_id_map) if v in state_event_map ] - for group, event_id_map in group_to_ids.iteritems() + for group, event_id_map in iteritems(group_to_ids) }) - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - if context.current_state_ids is None: - # AFAIK, this can never happen - logger.error( - "Non-outlier event %s had current_state_ids==None", - event.event_id) - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue - - self._simple_insert_txn( - txn, - table="state_groups", - values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, - }, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if context.prev_group: - is_in_db = self._simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": context.prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (context.prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group - ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, - }, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.delta_ids.iteritems() - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.current_state_ids.iteritems() - ], - ) - - # Prefill the state group cache with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=context.state_group, - value=dict(context.current_state_ids), - full=True, - ) - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } - for event_id, state_group_id in state_groups.iteritems() - ], - ) - - for event_id, state_group_id in state_groups.iteritems(): - txn.call_after( - self._get_state_group_for_event.prefill, - (event_id,), state_group_id - ) - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = (""" - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """) - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ results = {} - chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] + chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", @@ -422,6 +244,9 @@ class StateStore(SQLBaseStore): ( "AND type = ? AND state_key = ?", (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) ) for etype, state_key in types ] @@ -441,10 +266,19 @@ class StateStore(SQLBaseStore): key = (typ, state_key) results[group][key] = event_id else: + where_args = [] + where_clauses = [] + wildcard_types = False if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) + for typ in types: + if typ[1] is None: + where_clauses.append("(type = ?)") + where_args.append(typ[0]) + wildcard_types = True + else: + where_clauses.append("(type = ? AND state_key = ?)") + where_args.extend([typ[0], typ[1]]) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -461,7 +295,7 @@ class StateStore(SQLBaseStore): # after we finish deduping state, which requires this func) args = [next_group] if types: - args.extend(i for typ in types for i in typ) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" @@ -474,9 +308,17 @@ class StateStore(SQLBaseStore): if (typ, state_key) not in results[group] ) - # If the lengths match then we must have all the types, - # so no need to go walk further down the tree. - if types is not None and len(results[group]) == len(types): + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + types is not None and + not wildcard_types and + len(results[group]) == len(types) + ): break next_group = self._simple_select_one_onecol_txn( @@ -509,21 +351,21 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, types) state_event_map = yield self.get_events( - [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], get_prev_content=False ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in group_to_state[group].iteritems() + for k, v in iteritems(group_to_state[group]) if v in state_event_map } - for event_id, group in event_to_groups.iteritems() + for event_id, group in iteritems(event_to_groups) } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -546,12 +388,12 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, types) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_to_groups.iteritems() + for event_id, group in iteritems(event_to_groups) } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -665,7 +507,7 @@ class StateStore(SQLBaseStore): got_all = is_all or not missing_types return { - k: v for k, v in state_dict_ids.iteritems() + k: v for k, v in iteritems(state_dict_ids) if include(k[0], k[1]) }, missing_types, got_all @@ -685,10 +527,23 @@ class StateStore(SQLBaseStore): @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): - """Given list of groups returns dict of group -> list of state events - with matching types. `types` is a list of `(type, state_key)`, where - a `state_key` of None matches all state_keys. If `types` is None then - all events are returned. + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. """ if types: types = frozenset(types) @@ -697,7 +552,7 @@ class StateStore(SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, _, got_all = self._get_some_state_from_cache( - group, types + group, types, ) results[group] = state_dict_ids @@ -718,32 +573,266 @@ class StateStore(SQLBaseStore): # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence + # the DictionaryCache knows if it has *all* the state, but + # does not know if it has all of the keys of a particular type, + # which makes wildcard lookups expensive unless we have a complete + # cache. Hence, if we are doing a wildcard lookup, populate the + # cache fully so that we can do an efficient lookup next time. + + if types and any(k is None for (t, k) in types): + types_to_fetch = None + else: + types_to_fetch = types + group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types + missing_groups, types_to_fetch, ) - # Now we want to update the cache with all the things we fetched - # from the database. - for group, group_state_dict in group_to_state_dict.iteritems(): + for group, group_state_dict in iteritems(group_to_state_dict): state_dict = results[group] - state_dict.update( - ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) - for k, v in group_state_dict.iteritems() - ) - + # update the result, filtering by `types`. + if types: + for k, v in iteritems(group_state_dict): + (typ, _) = k + if k in types or (typ, None) in types: + state_dict[k] = v + else: + state_dict.update(group_state_dict) + + # update the cache with all the things we fetched from the + # database. self._state_group_cache.update( cache_seq_num, key=group, - value=state_dict, - full=(types is None), - known_absent=types, + value=group_state_dict, + fetched_keys=types_to_fetch, ) defer.returnValue(results) - def get_next_state_group(self): - return self._state_groups_id_gen.get_next() + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self._simple_insert_txn( + txn, + table="state_groups", + values={ + "id": state_group, + "room_id": room_id, + "event_id": event_id, + }, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn( + txn, prev_group + ) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_state_ids), + ) + + return state_group + + return self.runInteraction("store_state_group", _store_state_group_txn) + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + +class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + { + "state_group": state_group_id, + "event_id": event_id, + } + for event_id, state_group_id in iteritems(state_groups) + ], + ) + + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) @defer.inlineCallbacks def _background_deduplicate_state(self, progress, batch_size): @@ -767,7 +856,7 @@ class StateStore(SQLBaseStore): def reindex_txn(txn): new_last_state_group = last_state_group - for count in xrange(batch_size): + for count in range(batch_size): txn.execute( "SELECT id, room_id FROM state_groups" " WHERE ? < id AND id <= ?" @@ -825,7 +914,7 @@ class StateStore(SQLBaseStore): # of keys delta_state = { - key: value for key, value in curr_state.iteritems() + key: value for key, value in iteritems(curr_state) if prev_state.get(key, None) != value } @@ -865,7 +954,7 @@ class StateStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in delta_state.iteritems() + for key, state_id in iteritems(delta_state) ], ) |