diff options
author | Richard van der Hoff <github@rvanderhoff.org.uk> | 2017-11-15 17:23:01 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-11-15 17:23:01 +0000 |
commit | f959c01600d28d5d1a6892f929358ee8ba2b3769 (patch) | |
tree | 0019baf7cb16da23ddafdc19830e0893ec69f34c /synapse/storage/state.py | |
parent | Merge pull request #2677 from matrix-org/rav/spec_r0.3.0 (diff) | |
parent | Pull out bits of StateStore to a mixin (diff) | |
download | synapse-f959c01600d28d5d1a6892f929358ee8ba2b3769.tar.xz |
Merge pull request #2661 from matrix-org/rav/statereadstore
Pull out bits of StateStore to a mixin
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 441 |
1 files changed, 231 insertions, 210 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dd01b68762..360e3e4355 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches import intern_string -from synapse.util.stringutils import to_ascii -from synapse.storage.engines import PostgresEngine +from collections import namedtuple +import logging from twisted.internet import defer -from collections import namedtuple -import logging +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -40,23 +42,11 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateStore(SQLBaseStore): - """ Keeps track of the state at a given event. +class StateGroupReadStore(SQLBaseStore): + """The read-only parts of StateGroupStore - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. + None of these functions write to the state tables, so are suitable for + including in the SlavedStores. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" @@ -64,21 +54,10 @@ class StateStore(SQLBaseStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", + super(StateGroupReadStore, self).__init__(db_conn, hs) + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) @cached(max_entries=100000, iterable=True) @@ -190,178 +169,6 @@ class StateStore(SQLBaseStore): for group, event_id_map in group_to_ids.iteritems() }) - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - if context.current_state_ids is None: - # AFAIK, this can never happen - logger.error( - "Non-outlier event %s had current_state_ids==None", - event.event_id) - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue - - self._simple_insert_txn( - txn, - table="state_groups", - values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, - }, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if context.prev_group: - is_in_db = self._simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": context.prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (context.prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group - ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, - }, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.delta_ids.iteritems() - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.current_state_ids.iteritems() - ], - ) - - # Prefill the state group cache with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=context.state_group, - value=dict(context.current_state_ids), - full=True, - ) - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } - for event_id, state_group_id in state_groups.iteritems() - ], - ) - - for event_id, state_group_id in state_groups.iteritems(): - txn.call_after( - self._get_state_group_for_event.prefill, - (event_id,), state_group_id - ) - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = (""" - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """) - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) @@ -742,6 +549,220 @@ class StateStore(SQLBaseStore): defer.returnValue(results) + +class StateStore(StateGroupReadStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups WHERE id = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] + + def _store_mult_state_groups_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + if context.current_state_ids is None: + # AFAIK, this can never happen + logger.error( + "Non-outlier event %s had current_state_ids==None", + event.event_id) + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + if self._have_persisted_state_group_txn(txn, context.state_group): + continue + + self._simple_insert_txn( + txn, + table="state_groups", + values={ + "id": context.state_group, + "room_id": event.room_id, + "event_id": event.event_id, + }, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if context.prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": context.prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (context.prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn( + txn, context.prev_group + ) + if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": context.state_group, + "prev_state_group": context.prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.delta_ids.iteritems() + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.current_state_ids.iteritems() + ], + ) + + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=context.state_group, + value=dict(context.current_state_ids), + full=True, + ) + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + { + "state_group": state_group_id, + "event_id": event_id, + } + for event_id, state_group_id in state_groups.iteritems() + ], + ) + + for event_id, state_group_id in state_groups.iteritems(): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + def get_next_state_group(self): return self._state_groups_id_gen.get_next() |