diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 274 |
1 files changed, 156 insertions, 118 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 360e3e4355..ffa4246031 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupReadStore(SQLBaseStore): - """The read-only parts of StateGroupStore - - None of these functions write to the state tables, so are suitable for - including in the SlavedStores. +class StateGroupWorkerStore(SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" @@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" def __init__(self, db_conn, hs): - super(StateGroupReadStore, self).__init__(db_conn, hs) + super(StateGroupWorkerStore, self).__init__(db_conn, hs) self._state_group_cache = DictionaryCache( "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR @@ -143,6 +140,20 @@ class StateGroupReadStore(SQLBaseStore): defer.returnValue(group_to_state) @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the state IDs for the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + defer.returnValue(group_to_state[state_group]) + + @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids @@ -229,6 +240,9 @@ class StateGroupReadStore(SQLBaseStore): ( "AND type = ? AND state_key = ?", (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) ) for etype, state_key in types ] @@ -248,10 +262,19 @@ class StateGroupReadStore(SQLBaseStore): key = (typ, state_key) results[group][key] = event_id else: + where_args = [] + where_clauses = [] + wildcard_types = False if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) + for typ in types: + if typ[1] is None: + where_clauses.append("(type = ?)") + where_args.extend(typ[0]) + wildcard_types = True + else: + where_clauses.append("(type = ? AND state_key = ?)") + where_args.extend([typ[0], typ[1]]) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -268,7 +291,7 @@ class StateGroupReadStore(SQLBaseStore): # after we finish deduping state, which requires this func) args = [next_group] if types: - args.extend(i for typ in types for i in typ) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" @@ -281,9 +304,17 @@ class StateGroupReadStore(SQLBaseStore): if (typ, state_key) not in results[group] ) - # If the lengths match then we must have all the types, - # so no need to go walk further down the tree. - if types is not None and len(results[group]) == len(types): + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + types is not None and + not wildcard_types and + len(results[group]) == len(types) + ): break next_group = self._simple_select_one_onecol_txn( @@ -549,116 +580,66 @@ class StateGroupReadStore(SQLBaseStore): defer.returnValue(results) + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + """Store a new set of state, returning a newly assigned state group. -class StateStore(StateGroupReadStore, BackgroundUpdateStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. - """ - - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", - ) - - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. - if context.current_state_ids is None: + Returns: + Deferred[int]: The state group ID + """ + def _store_state_group_txn(txn): + if current_state_ids is None: # AFAIK, this can never happen - logger.error( - "Non-outlier event %s had current_state_ids==None", - event.event_id) - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue + raise Exception("current_state_ids cannot be None") - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue + state_group = self.database_engine.get_next_state_group_id(txn) self._simple_insert_txn( txn, table="state_groups", values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, + "id": state_group, + "room_id": room_id, + "event_id": event_id, }, ) # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. - if context.prev_group: + if prev_group: is_in_db = self._simple_select_one_onecol_txn( txn, table="state_groups", - keyvalues={"id": context.prev_group}, + keyvalues={"id": prev_group}, retcol="id", allow_none=True, ) if not is_in_db: raise Exception( "Trying to persist state with unpersisted prev_group: %r" - % (context.prev_group,) + % (prev_group,) ) potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group + txn, prev_group ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: self._simple_insert_txn( txn, table="state_group_edges", values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, + "state_group": state_group, + "prev_state_group": prev_group, }, ) @@ -667,13 +648,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): table="state_groups_state", values=[ { - "state_group": context.state_group, - "room_id": event.room_id, + "state_group": state_group, + "room_id": room_id, "type": key[0], "state_key": key[1], "event_id": state_id, } - for key, state_id in context.delta_ids.iteritems() + for key, state_id in delta_ids.iteritems() ], ) else: @@ -682,13 +663,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): table="state_groups_state", values=[ { - "state_group": context.state_group, - "room_id": event.room_id, + "state_group": state_group, + "room_id": room_id, "type": key[0], "state_key": key[1], "event_id": state_id, } - for key, state_id in context.current_state_ids.iteritems() + for key, state_id in current_state_ids.iteritems() ], ) @@ -699,28 +680,14 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, - key=context.state_group, - value=dict(context.current_state_ids), + key=state_group, + value=dict(current_state_ids), full=True, ) - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } - for event_id, state_group_id in state_groups.iteritems() - ], - ) + return state_group - for event_id, state_group_id in state_groups.iteritems(): - txn.call_after( - self._get_state_group_for_event.prefill, - (event_id,), state_group_id - ) + return self.runInteraction("store_state_group", _store_state_group_txn) def _count_state_group_hops_txn(self, txn, state_group): """Given a state group, count how many hops there are in the tree. @@ -763,8 +730,79 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): return count - def get_next_state_group(self): - return self._state_groups_id_gen.get_next() + +class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + { + "state_group": state_group_id, + "event_id": event_id, + } + for event_id, state_group_id in state_groups.iteritems() + ], + ) + + for event_id, state_group_id in state_groups.iteritems(): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) @defer.inlineCallbacks def _background_deduplicate_state(self, progress, batch_size): |