diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/_base.py | 43 | ||||
-rw-r--r-- | synapse/storage/state.py | 267 | ||||
-rw-r--r-- | synapse/storage/stream.py | 3 |
3 files changed, 224 insertions, 89 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0872a438f1..d4751769e4 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,6 +19,7 @@ from synapse.util.async import ObservableDeferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache +from synapse.util.dictionary_cache import DictionaryCache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator @@ -73,6 +74,11 @@ class Cache(object): self.thread = None caches_by_name[name] = self.cache + class Sentinel(object): + __slots__ = [] + + self.sentinel = Sentinel() + def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -84,22 +90,33 @@ class Cache(object): ) def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + try: + if len(keyargs) != self.keylen: + raise ValueError("Expected a key to have %d items", self.keylen) - if keyargs in self.cache: - cache_counter.inc_hits(self.name) - return self.cache[keyargs] + val = self.cache.get(keyargs, self.sentinel) + if val is not self.sentinel: + cache_counter.inc_hits(self.name) + return val - cache_counter.inc_misses(self.name) - raise KeyError() + cache_counter.inc_misses(self.name) + raise KeyError() + except KeyError: + raise + except: + logger.exception("Cache.get failed for %s" % (self.name,)) + raise def update(self, sequence, *args): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) + try: + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(*args) + except: + logger.exception("Cache.update failed for %s" % (self.name,)) + raise def prefill(self, *args): # because I can't *keyargs, value keyargs = args[:-1] @@ -368,6 +385,8 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000) + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 55c6d52890..48a4023558 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached, cachedInlineCallbacks from twisted.internet import defer +from synapse.util import unwrapFirstError from synapse.util.stringutils import random_string import logging @@ -44,52 +45,38 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids - - return res + event_and_groups = yield defer.gatherResults( + [ + self._get_state_group_for_event( + room_id, event_id, + ).addCallback(lambda group, event_id: (event_id, group), event_id) + for event_id in event_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) - states = yield self.runInteraction( - "get_state_groups", - f, - ) + groups = set(group for _, group in event_and_groups if group) - state_list = yield defer.gatherResults( + group_to_state = yield defer.gatherResults( [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() + self._get_state_for_group( + group, + ).addCallback(lambda state_dict, group: (group, state_dict), group) + for group in groups ], consumeErrors=True, - ) + ).addErrback(unwrapFirstError) - defer.returnValue(dict(state_list)) + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state + }) @cached(num_args=1) def _fetch_events_for_group(self, key, events): @@ -205,65 +192,195 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids): + @cached(num_args=2, lru=True, max_entries=10000) + def _get_state_groups_from_group(self, group, types): def f(txn): - groups = set() - event_to_group = {} - for event_id in event_ids: - # TODO: Remove this loop. - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - event_to_group[event_id] = group - groups.add(group) - - group_to_state_ids = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + return group, [ + r[0] + for r in txn.fetchall() + ] + + return self.runInteraction( + "_get_state_groups_from_group", + f, + ) + + @cached(num_args=3, lru=True, max_entries=20000) + def _get_state_for_event_id(self, room_id, event_id, types): + def f(txn): + type_and_state_sql = " OR ".join([ + "(type = ? AND state_key = ?)" + if typ[1] is not None + else "type = ?" + for typ in types + ]) - group_to_state_ids[group] = state_ids + sql = ( + "SELECT sg.event_id FROM state_groups_state as sg" + " INNER JOIN event_to_state_groups as e" + " ON e.state_group = sg.state_group" + " WHERE e.event_id = ? AND (%s)" + ) % (type_and_state_sql,) + + args = [event_id] + for typ, state_key in types: + args.extend( + [typ, state_key] if state_key is not None else [typ] + ) + txn.execute(sql, args) - return event_to_group, group_to_state_ids + return event_id, [ + r[0] + for r in txn.fetchall() + ] - res = yield self.runInteraction( - "annotate_events_with_state_groups", + return self.runInteraction( + "_get_state_for_event_id", f, ) - event_to_group, group_to_state_ids = res + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids, types): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. The state dicts will only have the type/state_keys + that are in the `types` list. + + Args: + room_id (str) + event_ids (list) + types (list): List of (type, state_key) tuples which are used to + filter the state fetched. `state_key` may be None, which matches + any `state_key` + + Returns: + deferred: A list of dicts corresponding to the event_ids given. + The dicts are mappings from (type, state_key) -> state_events + """ + event_and_groups = yield defer.gatherResults( + [ + self._get_state_group_for_event( + room_id, event_id, + ).addCallback(lambda group, event_id: (event_id, group), event_id) + for event_id in event_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + groups = set(group for _, group in event_and_groups) - state_list = yield defer.gatherResults( + res = yield defer.gatherResults( [ - self._fetch_events_for_group(group, vals) - for group, vals in group_to_state_ids.items() + self._get_state_for_group( + group, types + ).addCallback(lambda state_dict, group: (group, state_dict), group) + for group in groups ], consumeErrors=True, - ) + ).addErrback(unwrapFirstError) - state_dict = { - group: { - (ev.type, ev.state_key): ev - for ev in state - } - for group, state in state_list + group_to_state = dict(res) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_and_groups } defer.returnValue([ - state_dict.get(event_to_group.get(event, None), None) + event_to_state[event] for event in event_ids ]) + @cached(num_args=2, lru=True, max_entries=100000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @defer.inlineCallbacks + def _get_state_for_group(self, group, types=None): + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + if types is not None: + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + if is_all and types is None: + defer.returnValue(state_dict) + + if is_all or (types is not None and not missing_types): + def include(typ, state_key): + sentinel = object() + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + defer.returnValue({ + k: v + for k, v in state_dict.items() + if include(k[0], k[1]) + }) + + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence + _, state_ids = yield self._get_state_groups_from_group( + group, + frozenset(types) if types else None + ) + state_events = yield self._get_events(state_ids, get_prev_content=False) + state_dict = { + (e.type, e.state_key): e + for e in state_events + } + + # Update the cache + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + defer.returnValue(state_dict) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index af45fc5619..9db259d5fc 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -300,8 +300,7 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token, - with_feedback=False, from_token=None): + def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): # TODO (erikj): Handle compressed feedback end_token = RoomStreamToken.parse_stream_token(end_token) |