diff options
27 files changed, 927 insertions, 351 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7736d14fb5..58a6d6a0ed 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -23,7 +23,7 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent import synapse.metrics diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f7155fd8d3..2bfd0a40e0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -229,12 +229,14 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): - states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + event_to_state = yield self.store.get_state_for_events( + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ) ) - events_and_states = zip(events, states) - def redact_disallowed(event_and_state): event, state = event_and_state @@ -271,9 +273,10 @@ class FederationHandler(BaseHandler): return event - res = map(redact_disallowed, events_and_states) - - logger.info("_filter_events_for_server %r", res) + res = map(redact_disallowed, [ + (e, event_to_state[e.event_id]) + for e in events + ]) defer.returnValue(res) @@ -503,7 +506,7 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) states = yield defer.gatherResults([ - self.state_handler.resolve_state_groups([e]) + self.state_handler.resolve_state_groups(room_id, [e]) for e in event_ids ]) states = dict(zip(event_ids, [s[1] for s in states])) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d6d4f0978..29e81085d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -137,15 +137,15 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): - states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + event_id_to_state = yield self.store.get_state_for_events( + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) ) - events_and_states = zip(events, states) - - def allowed(event_and_state): - event, state = event_and_state - + def allowed(event, state): if event.type == EventTypes.RoomHistoryVisibility: return True @@ -175,10 +175,10 @@ class MessageHandler(BaseHandler): return True - events_and_states = filter(allowed, events_and_states) defer.returnValue([ - ev - for ev, _ in events_and_states + event + for event in events + if allowed(event, event_id_to_state[event.event_id]) ]) @defer.inlineCallbacks @@ -401,10 +401,14 @@ class MessageHandler(BaseHandler): except: logger.exception("Failed to get snapshot") - yield defer.gatherResults( - [handle_room(e) for e in room_list], - consumeErrors=True - ).addErrback(unwrapFirstError) + # Only do N rooms at once + n = 5 + d_list = [handle_room(e) for e in room_list] + for i in range(0, len(d_list), n): + yield defer.gatherResults( + d_list[i:i + n], + consumeErrors=True + ).addErrback(unwrapFirstError) ret = { "rooms": rooms_ret, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6cff6230c1..7206ae23d7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -294,15 +294,15 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): - states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + event_id_to_state = yield self.store.get_state_for_events( + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) ) - events_and_states = zip(events, states) - - def allowed(event_and_state): - event, state = event_and_state - + def allowed(event, state): if event.type == EventTypes.RoomHistoryVisibility: return True @@ -331,10 +331,11 @@ class SyncHandler(BaseHandler): return membership == Membership.INVITE return True - events_and_states = filter(allowed, events_and_states) + defer.returnValue([ - ev - for ev, _ in events_and_states + event + for event in events + if allowed(event, event_id_to_state[event.event_id]) ]) @defer.inlineCallbacks diff --git a/synapse/state.py b/synapse/state.py index 80da90a72c..1fe4d066bd 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.auth import AuthEventTypes @@ -96,7 +96,7 @@ class StateHandler(object): cache.ts = self.clock.time_msec() state = cache.state else: - res = yield self.resolve_state_groups(event_ids) + res = yield self.resolve_state_groups(room_id, event_ids) state = res[1] if event_type: @@ -155,13 +155,13 @@ class StateHandler(object): if event.is_state(): ret = yield self.resolve_state_groups( - [e for e, _ in event.prev_events], + event.room_id, [e for e, _ in event.prev_events], event_type=event.type, state_key=event.state_key, ) else: ret = yield self.resolve_state_groups( - [e for e, _ in event.prev_events], + event.room_id, [e for e, _ in event.prev_events], ) group, curr_state, prev_state = ret @@ -180,7 +180,7 @@ class StateHandler(object): @defer.inlineCallbacks @log_function - def resolve_state_groups(self, event_ids, event_type=None, state_key=""): + def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -205,7 +205,7 @@ class StateHandler(object): ) state_groups = yield self.store.get_state_groups( - event_ids + room_id, event_ids ) logger.debug( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 73eea157a4..1444767a52 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,25 +15,22 @@ import logging from synapse.api.errors import StoreError -from synapse.util.async import ObservableDeferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext -from synapse.util.lrucache import LruCache +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.caches.descriptors import Cache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple, OrderedDict +from collections import namedtuple -import functools -import inspect import sys import time import threading -DEBUG_CACHES = False logger = logging.getLogger(__name__) @@ -49,208 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) -caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) - - -_CacheSentinel = object() - - -class Cache(object): - - def __init__(self, name, max_entries=1000, keylen=1, lru=True): - if lru: - self.cache = LruCache(max_size=max_entries) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries - - self.name = name - self.keylen = keylen - self.sequence = 0 - self.thread = None - caches_by_name[name] = self.cache - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) - return val - - cache_counter.inc_misses(self.name) - - if default is _CacheSentinel: - raise KeyError() - else: - return default - - def update(self, sequence, key, value): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) - - def prefill(self, key, value): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) - - self.cache[key] = value - - def invalidate(self, key): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) - - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 - self.cache.pop(key, None) - - def invalidate_all(self): - self.check_thread() - self.sequence += 1 - self.cache.clear() - - -class CacheDescriptor(object): - """ A method decorator that applies a memoizing cache around the function. - - This caches deferreds, rather than the results themselves. Deferreds that - fail are removed from the cache. - - The function is presumed to take zero or more arguments, which are used in - a tuple as the key for the cache. Hits are served directly from the cache; - misses use the function body to generate the value. - - The wrapped function has an additional member, a callable called - "invalidate". This can be used to remove individual entries from the cache. - - The wrapped function has another additional callable, called "prefill", - which can be used to insert values into the cache specifically, without - calling the calculation function. - """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, - inlineCallbacks=False): - self.orig = orig - - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.max_entries = max_entries - self.num_args = num_args - self.lru = lru - - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - - self.cache = Cache( - name=self.orig.__name__, - max_entries=self.max_entries, - keylen=self.num_args, - lru=self.lru, - ) - - def __get__(self, obj, objtype=None): - - @functools.wraps(self.orig) - def wrapped(*args, **kwargs): - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) - try: - cached_result_d = self.cache.get(cache_key) - - observer = cached_result_d.observe() - if DEBUG_CACHES: - @defer.inlineCallbacks - def check_result(cached_result): - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, cache_key, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) - observer.addCallback(check_result) - - return observer - except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence - - ret = defer.maybeDeferred( - self.function_to_call, - obj, *args, **kwargs - ) - - def onErr(f): - self.cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) - - return ret.observe() - - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.prefill = self.cache.prefill - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -def cached(max_entries=1000, num_args=1, lru=True): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru - ) - - -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru, - inlineCallbacks=True, - ) - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object @@ -372,6 +167,8 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000) + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index f3947bbe89..d92028ea43 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.errors import SynapseError diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 910b6598a7..25cc84eb95 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -15,7 +15,8 @@ from twisted.internet import defer -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from syutil.base64util import encode_base64 import logging diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 49b8e37cfd..ffd6daa880 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, cachedInlineCallbacks +from _base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 576cf670cc..4f91a2b87c 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from twisted.internet import defer diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9b88ca7b39..5305b7e122 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer import logging diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index b79d6683ca..cac1a5657e 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4eaa088b36..aa446f94c6 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached class RegistrationStore(SQLBaseStore): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index dd5bc2c8fb..5e07b7e0e5 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks import collections import logging diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9f14f38f24..8eee2dfbcc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,7 +17,8 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.constants import Membership from synapse.types import UserID diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7ce51b9bdc..185f88fd7c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import ( + cached, cachedInlineCallbacks, cachedList +) from twisted.internet import defer @@ -44,59 +47,25 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ + if not event_ids: + defer.returnValue({}) - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids - - return res - - states = yield self.runInteraction( - "get_state_groups", - f, + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() - ], - consumeErrors=True, - ) + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) - defer.returnValue(dict(state_list)) - - def _fetch_events_for_group(self, key, events): - return self._get_events( - events, get_prev_content=False - ).addCallback( - lambda evs: (key, evs) - ) + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state.items() + }) def _store_state_groups_txn(self, txn, event, context): return self._store_mult_state_groups_txn(txn, [(event, context)]) @@ -204,64 +173,250 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + def _get_state_groups_from_groups(self, groups_and_types): + """Returns dictionary state_group -> state event ids + + Args: + groups_and_types (list): list of 2-tuple (`group`, `types`) + """ + def f(txn): + results = {} + for group, types in groups_and_types: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + results[group] = [r[0] for r in txn.fetchall()] + + return results + + return self.runInteraction( + "_get_state_groups_from_groups", + f, + ) + @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids): + def get_state_for_events(self, room_id, event_ids, types): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. The state dicts will only have the type/state_keys + that are in the `types` list. + + Args: + room_id (str) + event_ids (list) + types (list): List of (type, state_key) tuples which are used to + filter the state fetched. `state_key` may be None, which matches + any `state_key` + + Returns: + deferred: A list of dicts corresponding to the event_ids given. + The dicts are mappings from (type, state_key) -> state_events + """ + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @cached(num_args=2, lru=True, max_entries=100000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", + num_args=2) + def _get_state_group_for_events(self, room_id, event_ids): + """Returns mapping event_id -> state_group + """ def f(txn): - groups = set() - event_to_group = {} + results = {} for event_id in event_ids: - # TODO: Remove this loop. - group = self._simple_select_one_onecol_txn( + results[event_id] = self._simple_select_one_onecol_txn( txn, table="event_to_state_groups", - keyvalues={"event_id": event_id}, + keyvalues={ + "event_id": event_id, + }, retcol="state_group", allow_none=True, ) - if group: - event_to_group[event_id] = group - groups.add(group) - group_to_state_ids = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", + return results + + return self.runInteraction("_get_state_group_for_events", f) + + def _get_some_state_from_cache(self, group, types): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). + `missing_types` is the list of types that aren't in the cache for that + group. `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + + Args: + group: The state group to lookup + types (list): List of 2-tuples of the form (`type`, `state_key`), + where a `state_key` of `None` matches all state_keys for the + `type`. + """ + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + sentinel = object() + + def include(typ, state_key): + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + got_all = not (missing_types or types is None) + + return { + k: v for k, v in state_dict.items() + if include(k[0], k[1]) + }, missing_types, got_all + + def _get_all_state_from_cache(self, group): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool + indicating if we successfully retrieved all requests state from the + cache, if False we need to query the DB for the missing state. + + Args: + group: The state group to lookup + """ + is_all, state_dict = self._state_group_cache.get(group) + return state_dict, is_all + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, types=None): + """Given list of groups returns dict of group -> list of state events + with matching types. `types` is a list of `(type, state_key)`, where + a `state_key` of None matches all state_keys. If `types` is None then + all events are returned. + """ + results = {} + missing_groups_and_types = [] + if types is not None: + for group in set(groups): + state_dict, missing_types, got_all = self._get_some_state_from_cache( + group, types + ) + results[group] = state_dict + + if not got_all: + missing_groups_and_types.append((group, missing_types)) + else: + for group in set(groups): + state_dict, got_all = self._get_all_state_from_cache( + group ) + results[group] = state_dict - group_to_state_ids[group] = state_ids + if not got_all: + missing_groups_and_types.append((group, None)) - return event_to_group, group_to_state_ids + if not missing_groups_and_types: + defer.returnValue({ + group: { + type_tuple: event + for type_tuple, event in state.items() + if event + } + for group, state in results.items() + }) - res = yield self.runInteraction( - "annotate_events_with_state_groups", - f, - ) + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence - event_to_group, group_to_state_ids = res + group_state_dict = yield self._get_state_groups_from_groups( + missing_groups_and_types + ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in group_to_state_ids.items() - ], - consumeErrors=True, + state_events = yield self._get_events( + [e_id for l in group_state_dict.values() for e_id in l], + get_prev_content=False ) - state_dict = { - group: { - (ev.type, ev.state_key): ev - for ev in state - } - for group, state in state_list - } + state_events = {e.event_id: e for e in state_events} + + # Now we want to update the cache with all the things we fetched + # from the database. + for group, state_ids in group_state_dict.items(): + if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. + state_dict = {key: None for key in types} + state_dict.update(results[group]) + else: + state_dict = results[group] + + for event_id in state_ids: + state_event = state_events[event_id] + state_dict[(state_event.type, state_event.state_key)] = state_event + + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + results[group].update({ + key: value for key, value in state_dict.items() if value + }) - defer.returnValue([ - state_dict.get(event_to_group.get(event, None), None) - for event in event_ids - ]) + defer.returnValue(results) def _make_group_id(clock): diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index af45fc5619..d7fe423f5a 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -36,6 +36,7 @@ what sort order was used: from twisted.internet import defer from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function @@ -299,9 +300,8 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token, - with_feedback=False, from_token=None): + @cachedInlineCallbacks(num_args=4) + def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): # TODO (erikj): Handle compressed feedback end_token = RoomStreamToken.parse_stream_token(end_token) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 624da4a9dc..c8c7e6591a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from collections import namedtuple diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py new file mode 100644 index 0000000000..da0e06a468 --- /dev/null +++ b/synapse/util/caches/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse.metrics + +DEBUG_CACHES = False + +metrics = synapse.metrics.get_metrics_for("synapse.util.caches") + +caches_by_name = {} +cache_counter = metrics.register_cache( + "cache", + lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, + labels=["name"], +) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py new file mode 100644 index 0000000000..83bfec2f02 --- /dev/null +++ b/synapse/util/caches/descriptors.py @@ -0,0 +1,377 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError +from synapse.util.caches.lrucache import LruCache + +from . import caches_by_name, DEBUG_CACHES, cache_counter + +from twisted.internet import defer + +from collections import OrderedDict + +import functools +import inspect +import threading + +logger = logging.getLogger(__name__) + + +_CacheSentinel = object() + + +class Cache(object): + + def __init__(self, name, max_entries=1000, keylen=1, lru=True): + if lru: + self.cache = LruCache(max_size=max_entries) + self.max_entries = None + else: + self.cache = OrderedDict() + self.max_entries = max_entries + + self.name = name + self.keylen = keylen + self.sequence = 0 + self.thread = None + caches_by_name[name] = self.cache + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, default=_CacheSentinel): + val = self.cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + cache_counter.inc_hits(self.name) + return val + + cache_counter.inc_misses(self.name) + + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, key, value): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(key, value) + + def prefill(self, key, value): + if self.max_entries is not None: + while len(self.cache) >= self.max_entries: + self.cache.popitem(last=False) + + self.cache[key] = value + + def invalidate(self, key): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError( + "The cache key must be a tuple not %r" % (type(key),) + ) + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + +class CacheDescriptor(object): + """ A method decorator that applies a memoizing cache around the function. + + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + + The function is presumed to take zero or more arguments, which are used in + a tuple as the key for the cache. Hits are served directly from the cache; + misses use the function body to generate the value. + + The wrapped function has an additional member, a callable called + "invalidate". This can be used to remove individual entries from the cache. + + The wrapped function has another additional callable, called "prefill", + which can be used to insert values into the cache specifically, without + calling the calculation function. + """ + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + inlineCallbacks=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.max_entries = max_entries + self.num_args = num_args + self.lru = lru + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + self.cache = Cache( + name=self.orig.__name__, + max_entries=self.max_entries, + keylen=self.num_args, + lru=self.lru, + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + try: + cached_result_d = self.cache.get(cache_key) + + observer = cached_result_d.observe() + if DEBUG_CACHES: + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, cache_key, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) + observer.addCallback(check_result) + + return observer + except KeyError: + # Get the sequence number of the cache before reading from the + # database so that we can tell if the cache is invalidated + # while the SELECT is executing (SYN-369) + sequence = self.cache.sequence + + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + + def onErr(f): + self.cache.invalidate(cache_key) + return f + + ret.addErrback(onErr) + + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, cache_key, ret) + + return ret.observe() + + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class CacheListDescriptor(object): + """Wraps an existing cache to support bulk fetching of keys. + + Given a list of keys it looks in the cache to find any hits, then passes + the list of missing keys to the wrapped fucntion. + """ + + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + """ + Args: + orig (function) + cache (Cache) + list_name (str): Name of the argument which is the bulk lookup list + num_args (int) + inlineCallbacks (bool): Whether orig is a generator that should + be wrapped by defer.inlineCallbacks + """ + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.num_args = num_args + self.list_name = list_name + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.list_pos = self.arg_names.index(self.list_name) + + self.cache = cache + + self.sentinel = object() + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + if self.list_name not in self.arg_names: + raise Exception( + "Couldn't see arguments %r for %r." + % (self.list_name, cache.name,) + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + list_args = arg_dict[self.list_name] + + # cached is a dict arg -> deferred, where deferred results in a + # 2-tuple (`arg`, `result`) + cached = {} + missing = [] + for arg in list_args: + key = list(keyargs) + key[self.list_pos] = arg + + try: + res = self.cache.get(tuple(key)).observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached[arg] = res + except KeyError: + missing.append(arg) + + if missing: + sequence = self.cache.sequence + args_to_call = dict(arg_dict) + args_to_call[self.list_name] = missing + + ret_d = defer.maybeDeferred( + self.function_to_call, + **args_to_call + ) + + ret_d = ObservableDeferred(ret_d) + + # We need to create deferreds for each arg in the list so that + # we can insert the new deferred into the cache. + for arg in missing: + observer = ret_d.observe() + observer.addCallback(lambda r, arg: r[arg], arg) + + observer = ObservableDeferred(observer) + + key = list(keyargs) + key[self.list_pos] = arg + self.cache.update(sequence, tuple(key), observer) + + def invalidate(f, key): + self.cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) + + res = observer.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + + cached[arg] = res + + return defer.gatherResults( + cached.values(), + consumeErrors=True, + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +def cached(max_entries=1000, num_args=1, lru=True): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru + ) + + +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) + + +def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): + """Creates a descriptor that wraps a function in a `CacheListDescriptor`. + + Used to do batch lookups for an already created cache. A single argument + is specified as a list that is iterated through to lookup keys in the + original cache. A new list consisting of the keys that weren't in the cache + get passed to the original function, the result of which is stored in the + cache. + + Args: + cache (Cache): The underlying cache to use. + list_name (str): The name of the argument that is the list to use to + do batch lookups in the cache. + num_args (int): Number of arguments to use as the key in the cache. + inlineCallbacks (bool): Should the function be wrapped in an + `defer.inlineCallbacks`? + + Example: + + class Example(object): + @cached(num_args=2) + def do_something(self, first_arg): + ... + + @cachedList(do_something.cache, list_name="second_args", num_args=2) + def batch_do_something(self, first_arg, second_args): + ... + """ + return lambda orig: CacheListDescriptor( + orig, + cache=cache, + list_name=list_name, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py new file mode 100644 index 0000000000..e69adf62fe --- /dev/null +++ b/synapse/util/caches/dictionary_cache.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.caches.lrucache import LruCache +from collections import namedtuple +from . import caches_by_name, cache_counter +import threading +import logging + + +logger = logging.getLogger(__name__) + + +DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) + + +class DictionaryCache(object): + """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. + fetching a subset of dictionary keys for a particular key. + """ + + def __init__(self, name, max_entries=1000): + self.cache = LruCache(max_size=max_entries) + + self.name = name + self.sequence = 0 + self.thread = None + # caches_by_name[name] = self.cache + + class Sentinel(object): + __slots__ = [] + + self.sentinel = Sentinel() + caches_by_name[name] = self.cache + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, dict_keys=None): + entry = self.cache.get(key, self.sentinel) + if entry is not self.sentinel: + cache_counter.inc_hits(self.name) + + if dict_keys is None: + return DictionaryEntry(entry.full, dict(entry.value)) + else: + return DictionaryEntry(entry.full, { + k: entry.value[k] + for k in dict_keys + if k in entry.value + }) + + cache_counter.inc_misses(self.name) + return DictionaryEntry(False, {}) + + def invalidate(self, key): + self.check_thread() + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + def update(self, sequence, key, value, full=False): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + if full: + self._insert(key, value) + else: + self._update_or_insert(key, value) + + def _update_or_insert(self, key, value): + entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + entry.value.update(value) + + def _insert(self, key, value): + self.cache[key] = DictionaryEntry(True, value) diff --git a/synapse/util/expiringcache.py b/synapse/util/caches/expiringcache.py index 06d1eea01b..06d1eea01b 100644 --- a/synapse/util/expiringcache.py +++ b/synapse/util/caches/expiringcache.py diff --git a/synapse/util/lrucache.py b/synapse/util/caches/lrucache.py index cacd7e45fa..cacd7e45fa 100644 --- a/synapse/util/lrucache.py +++ b/synapse/util/caches/lrucache.py diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index abee2f631d..e72cace8ff 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.util.async import ObservableDeferred -from synapse.storage._base import Cache, cached +from synapse.util.caches.descriptors import Cache, cached class CacheTestCase(unittest.TestCase): diff --git a/tests/test_state.py b/tests/test_state.py index fea25f7021..5845358754 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -69,7 +69,7 @@ class StateGroupStore(object): self._next_group = 1 - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py new file mode 100644 index 0000000000..54ff26cd97 --- /dev/null +++ b/tests/util/test_dict_cache.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer +from tests import unittest + +from synapse.util.caches.dictionary_cache import DictionaryCache + + +class DictCacheTestCase(unittest.TestCase): + + def setUp(self): + self.cache = DictionaryCache("foobar") + + def test_simple_cache_hit_full(self): + key = "test_simple_cache_hit_full" + + v = self.cache.get(key) + self.assertEqual((False, {}), v) + + seq = self.cache.sequence + test_value = {"test": "test_simple_cache_hit_full"} + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key) + self.assertEqual(test_value, c.value) + + def test_simple_cache_hit_partial(self): + key = "test_simple_cache_hit_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_hit_partial" + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test"]) + self.assertEqual(test_value, c.value) + + def test_simple_cache_miss_partial(self): + key = "test_simple_cache_miss_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_miss_partial" + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test2"]) + self.assertEqual({}, c.value) + + def test_simple_cache_hit_miss_partial(self): + key = "test_simple_cache_hit_miss_partial" + + seq = self.cache.sequence + test_value = { + "test": "test_simple_cache_hit_miss_partial", + "test2": "test_simple_cache_hit_miss_partial2", + "test3": "test_simple_cache_hit_miss_partial3", + } + self.cache.update(seq, key, test_value, full=True) + + c = self.cache.get(key, ["test2"]) + self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) + + def test_multi_insert(self): + key = "test_simple_cache_hit_miss_partial" + + seq = self.cache.sequence + test_value_1 = { + "test": "test_simple_cache_hit_miss_partial", + } + self.cache.update(seq, key, test_value_1, full=False) + + seq = self.cache.sequence + test_value_2 = { + "test2": "test_simple_cache_hit_miss_partial2", + } + self.cache.update(seq, key, test_value_2, full=False) + + c = self.cache.get(key) + self.assertEqual( + { + "test": "test_simple_cache_hit_miss_partial", + "test2": "test_simple_cache_hit_miss_partial2", + }, + c.value + ) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index ab934bf928..fc5a904323 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -16,7 +16,7 @@ from .. import unittest -from synapse.util.lrucache import LruCache +from synapse.util.caches.lrucache import LruCache class LruCacheTestCase(unittest.TestCase): @@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase): cache["key"] = 1 self.assertEquals(cache.pop("key"), 1) self.assertEquals(cache.pop("key"), None) - - |