diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 74 |
1 files changed, 72 insertions, 2 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a7c3d401d4..5673e4aa96 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from collections import namedtuple import logging @@ -29,6 +30,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + class StateStore(SQLBaseStore): """ Keeps track of the state at a given event. @@ -98,6 +109,46 @@ class StateStore(SQLBaseStore): _get_current_state_ids_txn, ) + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + }, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + }, + retcols=("type", "state_key", "event_id",) + ) + + return _GetStateGroupDelta(prev_group, { + (row["type"], row["state_key"]): row["event_id"] + for row in delta_ids + }) + return self.runInteraction( + "get_state_group_delta", + _get_state_group_delta_txn, + ) + @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): if not event_ids: @@ -184,6 +235,19 @@ class StateStore(SQLBaseStore): # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if context.prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": context.prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (context.prev_group,) + ) + potential_hops = self._count_state_group_hops_txn( txn, context.prev_group ) @@ -251,6 +315,12 @@ class StateStore(SQLBaseStore): ], ) + for event_id, state_group_id in state_groups.iteritems(): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) + def _count_state_group_hops_txn(self, txn, state_group): """Given a state group, count how many hops there are in the tree. @@ -520,8 +590,8 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_ids_for_events([event_id], types) defer.returnValue(state_map[event_id]) - @cached(num_args=2, max_entries=50000) - def _get_state_group_for_event(self, room_id, event_id): + @cached(max_entries=50000) + def _get_state_group_for_event(self, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", keyvalues={ |