From 80a61330ee794147b213b1d54f2292a1c9adc002 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jul 2015 11:41:55 +0100 Subject: Add basic storage functions for handling of receipts --- synapse/storage/_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8d33def6c6..8f812f0fd7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -329,13 +329,14 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine - self._stream_id_gen = StreamIdGenerator() + self._stream_id_gen = StreamIdGenerator("events", "stream_ordering") self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id") def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() -- cgit 1.4.1 From 39e21ea51cf94f436f1f845f24c1b04f11b92c6f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 27 Jul 2015 13:57:29 +0100 Subject: Add support for using keyword arguments with cached functions --- synapse/storage/_base.py | 40 ++++++++++++++++++++++++++++++++++------ synapse/storage/keys.py | 5 ++--- synapse/storage/push_rule.py | 8 +++----- synapse/storage/receipts.py | 5 ++--- synapse/storage/room.py | 5 ++--- synapse/storage/state.py | 5 ++--- 6 files changed, 45 insertions(+), 23 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8f812f0fd7..f1265541ba 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -27,6 +27,7 @@ from twisted.internet import defer from collections import namedtuple, OrderedDict import functools +import inspect import sys import time import threading @@ -141,13 +142,28 @@ class CacheDescriptor(object): which can be used to insert values into the cache specifically, without calling the calculation function. """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=False): + def __init__(self, orig, max_entries=1000, num_args=1, lru=False, + inlineCallbacks=False): self.orig = orig + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + self.max_entries = max_entries self.num_args = num_args self.lru = lru + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + def __get__(self, obj, objtype=None): cache = Cache( name=self.orig.__name__, @@ -158,11 +174,13 @@ class CacheDescriptor(object): @functools.wraps(self.orig) @defer.inlineCallbacks - def wrapped(*keyargs): + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] try: - cached_result = cache.get(*keyargs[:self.num_args]) + cached_result = cache.get(*keyargs) if DEBUG_CACHES: - actual_result = yield self.orig(obj, *keyargs) + actual_result = yield self.function_to_call(obj, *args, **kwargs) if actual_result != cached_result: logger.error( "Stale cache entry %s%r: cached: %r, actual %r", @@ -177,9 +195,9 @@ class CacheDescriptor(object): # while the SELECT is executing (SYN-369) sequence = cache.sequence - ret = yield self.orig(obj, *keyargs) + ret = yield self.function_to_call(obj, *args, **kwargs) - cache.update(sequence, *keyargs[:self.num_args] + (ret,)) + cache.update(sequence, *(keyargs + [ret])) defer.returnValue(ret) @@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=False): ) +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 940a5f7e08..e3f98f0cde 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, cached +from _base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer @@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_all_server_verify_keys(self, server_name): rows = yield self._simple_select_list( table="server_signature_keys", diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 4cac118d17..a220f3632e 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer import logging @@ -23,8 +23,7 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_for_user(self, user_name): rows = yield self._simple_select_list( table=PushRuleTable.table_name, @@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rows) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_enabled_for_user(self, user_name): results = yield self._simple_select_list( table=PushRuleEnableTable.table_name, diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 7a6af98d98..b79d6683ca 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer @@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_max_token(self) - @cached - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_graph_receipts_for_room(self, room_id): """Get receipts for sending to remote servers. """ diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 4612a8aa83..dd5bc2c8fb 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks import collections import logging @@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore): } ) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_room_name_and_aliases(self, room_id): def f(txn): sql = ( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 47bec65497..55c6d52890 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cached, cachedInlineCallbacks from twisted.internet import defer @@ -189,8 +189,7 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=3) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=3) def get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = ( -- cgit 1.4.1 From 4d6cb8814e134eba644afeed7bd49df0c7951342 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 4 Aug 2015 09:32:23 +0100 Subject: Speed up event filtering (for ACL) logic --- synapse/handlers/federation.py | 6 ++- synapse/handlers/message.py | 6 ++- synapse/handlers/sync.py | 6 ++- synapse/storage/_base.py | 10 +++- synapse/storage/state.py | 117 ++++++++++++++++++++++++++++------------- 5 files changed, 102 insertions(+), 43 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f7155fd8d3..22f534e49a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -230,7 +230,11 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, None), + ) ) events_and_states = zip(events, states) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d6d4f0978..765b14d994 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -138,7 +138,11 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) ) events_and_states = zip(events, states) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6cff6230c1..8f58774b31 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -295,7 +295,11 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def _filter_events_for_client(self, user_id, room_id, events): states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], + room_id, frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) ) events_and_states = zip(events, states) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8f812f0fd7..7b76ee3b73 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -71,6 +71,11 @@ class Cache(object): self.thread = None caches_by_name[name] = self.cache + class Sentinel(object): + __slots__ = [] + + self.sentinel = Sentinel() + def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -85,9 +90,10 @@ class Cache(object): if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) - if keyargs in self.cache: + val = self.cache.get(keyargs, self.sentinel) + if val is not self.sentinel: cache_counter.inc_hits(self.name) - return self.cache[keyargs] + return val cache_counter.inc_misses(self.name) raise KeyError() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 47bec65497..7e9bd232cf 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached from twisted.internet import defer +from synapse.util import unwrapFirstError from synapse.util.stringutils import random_string import logging @@ -206,62 +207,102 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids): + @cached(num_args=3, lru=True) + def _get_state_groups_from_group(self, room_id, group, types): def f(txn): - groups = set() - event_to_group = {} - for event_id in event_ids: - # TODO: Remove this loop. - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - event_to_group[event_id] = group - groups.add(group) - - group_to_state_ids = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " room_id = ? AND state_group = ? AND (%s)" + ) % (" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),) - group_to_state_ids[group] = state_ids + args = [room_id, group] + args.extend([i for typ in types for i in typ]) + txn.execute(sql, args) - return event_to_group, group_to_state_ids + return group, [ + r[0] + for r in txn.fetchall() + ] - res = yield self.runInteraction( - "annotate_events_with_state_groups", + return self.runInteraction( + "_get_state_groups_from_group", f, ) - event_to_group, group_to_state_ids = res + @cached(num_args=3, lru=True, max_entries=100000) + def _get_state_for_event_id(self, room_id, event_id, types): + def f(txn): + type_and_state_sql = " OR ".join([ + "(type = ? AND state_key = ?)" + if typ[1] is not None + else "type = ?" + for typ in types + ]) - state_list = yield defer.gatherResults( + sql = ( + "SELECT sg.event_id FROM state_groups_state as sg" + " INNER JOIN event_to_state_groups as e" + " ON e.state_group = sg.state_group" + " WHERE e.event_id = ? AND (%s)" + ) % (type_and_state_sql,) + + args = [event_id] + for typ, state_key in types: + args.extend( + [typ, state_key] if state_key is not None else [typ] + ) + txn.execute(sql, args) + + return event_id, [ + r[0] + for r in txn.fetchall() + ] + + return self.runInteraction( + "_get_state_for_event_id", + f, + ) + + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids, types): + set_types = frozenset(types) + res = yield defer.gatherResults( [ - self._fetch_events_for_group(group, vals) - for group, vals in group_to_state_ids.items() + self._get_state_for_event_id( + room_id, event_id, set_types, + ) + for event_id in event_ids ], consumeErrors=True, + ).addErrback(unwrapFirstError) + + event_to_state_ids = dict(res) + + event_dict = yield self._get_events( + [ + item + for lst in event_to_state_ids.values() + for item in lst + ], + get_prev_content=False + ).addCallback( + lambda evs: {ev.event_id: ev for ev in evs} ) - state_dict = { - group: { + event_to_state = { + event_id: { (ev.type, ev.state_key): ev - for ev in state + for ev in [ + event_dict[state_id] + for state_id in state_ids + if state_id in event_dict + ] } - for group, state in state_list + for event_id, state_ids in event_to_state_ids.items() } defer.returnValue([ - state_dict.get(event_to_group.get(event, None), None) + event_to_state[event] for event in event_ids ]) -- cgit 1.4.1 From 07507643cb6a2fde1a87d229f8d77525627a0632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2015 15:06:51 +0100 Subject: Use dictionary cache to do group -> state fetching --- synapse/handlers/federation.py | 2 +- synapse/state.py | 10 +- synapse/storage/_base.py | 39 +++++--- synapse/storage/state.py | 191 ++++++++++++++++++++++++++------------- synapse/storage/stream.py | 3 +- synapse/util/dictionary_cache.py | 58 +++++++----- tests/test_state.py | 2 +- 7 files changed, 195 insertions(+), 110 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 22f534e49a..90649af9e1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -507,7 +507,7 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) states = yield defer.gatherResults([ - self.state_handler.resolve_state_groups([e]) + self.state_handler.resolve_state_groups(room_id, [e]) for e in event_ids ]) states = dict(zip(event_ids, [s[1] for s in states])) diff --git a/synapse/state.py b/synapse/state.py index 80da90a72c..b5e5d7bbda 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -96,7 +96,7 @@ class StateHandler(object): cache.ts = self.clock.time_msec() state = cache.state else: - res = yield self.resolve_state_groups(event_ids) + res = yield self.resolve_state_groups(room_id, event_ids) state = res[1] if event_type: @@ -155,13 +155,13 @@ class StateHandler(object): if event.is_state(): ret = yield self.resolve_state_groups( - [e for e, _ in event.prev_events], + event.room_id, [e for e, _ in event.prev_events], event_type=event.type, state_key=event.state_key, ) else: ret = yield self.resolve_state_groups( - [e for e, _ in event.prev_events], + event.room_id, [e for e, _ in event.prev_events], ) group, curr_state, prev_state = ret @@ -180,7 +180,7 @@ class StateHandler(object): @defer.inlineCallbacks @log_function - def resolve_state_groups(self, event_ids, event_type=None, state_key=""): + def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -205,7 +205,7 @@ class StateHandler(object): ) state_groups = yield self.store.get_state_groups( - event_ids + room_id, event_ids ) logger.debug( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7b76ee3b73..803b9d599d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,6 +18,7 @@ from synapse.api.errors import StoreError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache +from synapse.util.dictionary_cache import DictionaryCache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator @@ -87,23 +88,33 @@ class Cache(object): ) def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + try: + if len(keyargs) != self.keylen: + raise ValueError("Expected a key to have %d items", self.keylen) - val = self.cache.get(keyargs, self.sentinel) - if val is not self.sentinel: - cache_counter.inc_hits(self.name) - return val + val = self.cache.get(keyargs, self.sentinel) + if val is not self.sentinel: + cache_counter.inc_hits(self.name) + return val - cache_counter.inc_misses(self.name) - raise KeyError() + cache_counter.inc_misses(self.name) + raise KeyError() + except KeyError: + raise + except: + logger.exception("Cache.get failed for %s" % (self.name,)) + raise def update(self, sequence, *args): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) + try: + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(*args) + except: + logger.exception("Cache.update failed for %s" % (self.name,)) + raise def prefill(self, *args): # because I can't *keyargs, value keyargs = args[:-1] @@ -327,6 +338,8 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000) + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 91a5ae86a4..a967b3d44b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -45,52 +45,38 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids + event_and_groups = yield defer.gatherResults( + [ + self._get_state_group_for_event( + room_id, event_id, + ).addCallback(lambda group, event_id: (event_id, group), event_id) + for event_id in event_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) - return res + groups = set(group for _, group in event_and_groups if group) - states = yield self.runInteraction( - "get_state_groups", - f, - ) - - state_list = yield defer.gatherResults( + group_to_state = yield defer.gatherResults( [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() + self._get_state_for_group( + group, + ).addCallback(lambda state_dict, group: (group, state_dict), group) + for group in groups ], consumeErrors=True, - ) + ).addErrback(unwrapFirstError) - defer.returnValue(dict(state_list)) + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state + }) @cached(num_args=1) def _fetch_events_for_group(self, key, events): @@ -207,16 +193,25 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=3, lru=True) - def _get_state_groups_from_group(self, room_id, group, types): + @cached(num_args=2, lru=True, max_entries=10000) + def _get_state_groups_from_group(self, group, types): def f(txn): + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + sql = ( "SELECT event_id FROM state_groups_state WHERE" - " room_id = ? AND state_group = ? AND (%s)" - ) % (" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),) + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) - args = [room_id, group] - args.extend([i for typ in types for i in typ]) txn.execute(sql, args) return group, [ @@ -229,7 +224,7 @@ class StateStore(SQLBaseStore): f, ) - @cached(num_args=3, lru=True, max_entries=100000) + @cached(num_args=3, lru=True, max_entries=20000) def _get_state_for_event_id(self, room_id, event_id, types): def f(txn): type_and_state_sql = " OR ".join([ @@ -280,40 +275,33 @@ class StateStore(SQLBaseStore): deferred: A list of dicts corresponding to the event_ids given. The dicts are mappings from (type, state_key) -> state_events """ - set_types = frozenset(types) - res = yield defer.gatherResults( + event_and_groups = yield defer.gatherResults( [ - self._get_state_for_event_id( - room_id, event_id, set_types, - ) + self._get_state_group_for_event( + room_id, event_id, + ).addCallback(lambda group, event_id: (event_id, group), event_id) for event_id in event_ids ], consumeErrors=True, ).addErrback(unwrapFirstError) - event_to_state_ids = dict(res) + groups = set(group for _, group in event_and_groups) - event_dict = yield self._get_events( + res = yield defer.gatherResults( [ - item - for lst in event_to_state_ids.values() - for item in lst + self._get_state_for_group( + group, types + ).addCallback(lambda state_dict, group: (group, state_dict), group) + for group in groups ], - get_prev_content=False - ).addCallback( - lambda evs: {ev.event_id: ev for ev in evs} - ) + consumeErrors=True, + ).addErrback(unwrapFirstError) + + group_to_state = dict(res) event_to_state = { - event_id: { - (ev.type, ev.state_key): ev - for ev in [ - event_dict[state_id] - for state_id in state_ids - if state_id in event_dict - ] - } - for event_id, state_ids in event_to_state_ids.items() + event_id: group_to_state[group] + for event_id, group in event_and_groups } defer.returnValue([ @@ -321,6 +309,79 @@ class StateStore(SQLBaseStore): for event in event_ids ]) + @cached(num_args=2, lru=True, max_entries=100000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @defer.inlineCallbacks + def _get_state_for_group(self, group, types=None): + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + if types is not None: + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + if is_all and types is None: + defer.returnValue(state_dict) + + if is_all or (types is not None and not missing_types): + def include(typ, state_key): + sentinel = object() + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + defer.returnValue({ + k: v + for k, v in state_dict.items() + if include(k[0], k[1]) + }) + + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence + _, state_ids = yield self._get_state_groups_from_group( + group, + frozenset(types) if types else None + ) + state_events = yield self._get_events(state_ids, get_prev_content=False) + state_dict = { + (e.type, e.state_key): e + for e in state_events + } + + # Update the cache + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + defer.returnValue(state_dict) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index af45fc5619..9db259d5fc 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -300,8 +300,7 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token, - with_feedback=False, from_token=None): + def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): # TODO (erikj): Handle compressed feedback end_token = RoomStreamToken.parse_stream_token(end_token) diff --git a/synapse/util/dictionary_cache.py b/synapse/util/dictionary_cache.py index 0877cc79f6..38b131677c 100644 --- a/synapse/util/dictionary_cache.py +++ b/synapse/util/dictionary_cache.py @@ -16,6 +16,10 @@ from synapse.util.lrucache import LruCache from collections import namedtuple import threading +import logging + + +logger = logging.getLogger(__name__) DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) @@ -47,21 +51,25 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: - # cache_counter.inc_hits(self.name) - - if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) - else: - return DictionaryEntry(entry.full, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) - - # cache_counter.inc_misses(self.name) - return DictionaryEntry(False, {}) + try: + entry = self.cache.get(key, self.sentinel) + if entry is not self.sentinel: + # cache_counter.inc_hits(self.name) + + if dict_keys is None: + return DictionaryEntry(entry.full, dict(entry.value)) + else: + return DictionaryEntry(entry.full, { + k: entry.value[k] + for k in dict_keys + if k in entry.value + }) + + # cache_counter.inc_misses(self.name) + return DictionaryEntry(False, {}) + except: + logger.exception("get failed") + raise def invalidate(self, key): self.check_thread() @@ -77,14 +85,18 @@ class DictionaryCache(object): self.cache.clear() def update(self, sequence, key, value, full=False): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - if full: - self._insert(key, value) - else: - self._update_or_insert(key, value) + try: + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + if full: + self._insert(key, value) + else: + self._update_or_insert(key, value) + except: + logger.exception("update failed") + raise def _update_or_insert(self, key, value): entry = self.cache.setdefault(key, DictionaryEntry(False, {})) diff --git a/tests/test_state.py b/tests/test_state.py index fea25f7021..5845358754 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -69,7 +69,7 @@ class StateGroupStore(object): self._next_group = 1 - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) -- cgit 1.4.1 From a89559d79714b8dd07ec5a73f2629d003fcad92c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2015 15:39:47 +0100 Subject: Use LRU cache by default --- synapse/storage/_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8f812f0fd7..2a601e37e3 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,7 @@ cache_counter = metrics.register_cache( class Cache(object): - def __init__(self, name, max_entries=1000, keylen=1, lru=False): + def __init__(self, name, max_entries=1000, keylen=1, lru=True): if lru: self.cache = LruCache(max_size=max_entries) self.max_entries = None @@ -141,7 +141,7 @@ class CacheDescriptor(object): which can be used to insert values into the cache specifically, without calling the calculation function. """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=False): + def __init__(self, orig, max_entries=1000, num_args=1, lru=True): self.orig = orig self.max_entries = max_entries @@ -192,7 +192,7 @@ class CacheDescriptor(object): return wrapped -def cached(max_entries=1000, num_args=1, lru=False): +def cached(max_entries=1000, num_args=1, lru=True): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, -- cgit 1.4.1 From 7eea3e356ff58168f3525879a8eb684f0681ee68 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 6 Aug 2015 13:33:34 +0100 Subject: Make @cached cache deferreds rather than the deferreds' values --- synapse/storage/_base.py | 21 ++++++++------------- synapse/util/async.py | 9 +++++++-- tests/storage/test__base.py | 11 +++++++---- 3 files changed, 22 insertions(+), 19 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f1265541ba..8604d38c3e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,6 +15,7 @@ import logging from synapse.api.errors import StoreError +from synapse.util.async import ObservableDeferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -173,33 +174,27 @@ class CacheDescriptor(object): ) @functools.wraps(self.orig) - @defer.inlineCallbacks def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] try: cached_result = cache.get(*keyargs) - if DEBUG_CACHES: - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, keyargs, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) + return cached_result.observe() except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) sequence = cache.sequence - ret = yield self.function_to_call(obj, *args, **kwargs) + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + ret = ObservableDeferred(ret, consumeErrors=False) cache.update(sequence, *(keyargs + [ret])) - defer.returnValue(ret) + return ret.observe() wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all diff --git a/synapse/util/async.py b/synapse/util/async.py index 5a1d545c96..7bf2d38bb8 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -51,7 +51,7 @@ class ObservableDeferred(object): object.__setattr__(self, "_observers", set()) def callback(r): - self._result = (True, r) + object.__setattr__(self, "_result", (True, r)) while self._observers: try: self._observers.pop().callback(r) @@ -60,7 +60,7 @@ class ObservableDeferred(object): return r def errback(f): - self._result = (False, f) + object.__setattr__(self, "_result", (False, f)) while self._observers: try: self._observers.pop().errback(f) @@ -97,3 +97,8 @@ class ObservableDeferred(object): def __setattr__(self, name, value): setattr(self._deferred, name, value) + + def __repr__(self): + return "" % ( + id(self), self._result, self._deferred, + ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8c3d2952bd..8fa305d18a 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -17,6 +17,8 @@ from tests import unittest from twisted.internet import defer +from synapse.util.async import ObservableDeferred + from synapse.storage._base import Cache, cached @@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertTrue(callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])) - @defer.inlineCallbacks def test_prefill(self): callcount = [0] + d = defer.succeed(123) + class A(object): @cached() def func(self, key): callcount[0] += 1 - return key + return d a = A() - a.func.prefill("foo", 123) + a.func.prefill("foo", ObservableDeferred(d)) - self.assertEquals((yield a.func("foo")), 123) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) -- cgit 1.4.1 From 433314cc34295ffefca349b9fc6914d81f2521e0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 6 Aug 2015 14:01:05 +0100 Subject: Re-implement DEBUG_CACHES flag --- synapse/storage/_base.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7b2be6745e..8384299a13 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -178,8 +178,23 @@ class CacheDescriptor(object): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] try: - cached_result = cache.get(*keyargs) - return cached_result.observe() + cached_result_d = cache.get(*keyargs) + + if DEBUG_CACHES: + + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, keyargs, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + cached_result_d.observe().addCallback(check_result) + + return cached_result_d.observe() except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated -- cgit 1.4.1 From b811c9857491c0569ae367708721fbbaebf3adab Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 6 Aug 2015 14:01:27 +0100 Subject: Remove failed deferreds from cache --- synapse/storage/_base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8384299a13..99c0948754 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -205,8 +205,14 @@ class CacheDescriptor(object): self.function_to_call, obj, *args, **kwargs ) - ret = ObservableDeferred(ret, consumeErrors=False) + def onErr(f): + cache.invalidate(*keyargs) + return f + + ret.addErrback(onErr) + + ret = ObservableDeferred(ret, consumeErrors=False) cache.update(sequence, *(keyargs + [ret])) return ret.observe() -- cgit 1.4.1 From 63b1eaf32c57f0e1bae9b6a3768a560165637aee Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 6 Aug 2015 14:02:50 +0100 Subject: Docs --- synapse/storage/_base.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 99c0948754..0872a438f1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -132,6 +132,9 @@ class Cache(object): class CacheDescriptor(object): """ A method decorator that applies a memoizing cache around the function. + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + The function is presumed to take zero or more arguments, which are used in a tuple as the key for the cache. Hits are served directly from the cache; misses use the function body to generate the value. -- cgit 1.4.1 From b8e386db59eea6a59b8338acfd8ea42632d539be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 11:52:21 +0100 Subject: Change Cache to not use *args in its interface --- synapse/storage/__init__.py | 4 +- synapse/storage/_base.py | 85 +++++++++++++++---------------------- synapse/storage/directory.py | 4 +- synapse/storage/event_federation.py | 4 +- synapse/storage/events.py | 21 ++++----- synapse/storage/keys.py | 2 +- synapse/storage/presence.py | 4 +- synapse/storage/push_rule.py | 16 +++---- synapse/storage/registration.py | 2 +- synapse/storage/roommember.py | 6 +-- tests/storage/test__base.py | 12 +++--- 11 files changed, 73 insertions(+), 87 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 71d5d92500..1a6a8a3762 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore, key = (user.to_string(), access_token, device_id, ip) try: - last_seen = self.client_ip_last_seen.get(*key) + last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None @@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: defer.returnValue(None) - self.client_ip_last_seen.prefill(*key + (now,)) + self.client_ip_last_seen.prefill(key, now) # It's safe not to lock here: a) no unique constraint, # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index d4751769e4..32089b05e5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -58,6 +58,9 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True): @@ -74,11 +77,6 @@ class Cache(object): self.thread = None caches_by_name[name] = self.cache - class Sentinel(object): - __slots__ = [] - - self.sentinel = Sentinel() - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -89,52 +87,38 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - try: - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + def get(self, keyargs, default=_CacheSentinel): + val = self.cache.get(keyargs, _CacheSentinel) + if val is not _CacheSentinel: + cache_counter.inc_hits(self.name) + return val - val = self.cache.get(keyargs, self.sentinel) - if val is not self.sentinel: - cache_counter.inc_hits(self.name) - return val + cache_counter.inc_misses(self.name) - cache_counter.inc_misses(self.name) + if default is _CacheSentinel: raise KeyError() - except KeyError: - raise - except: - logger.exception("Cache.get failed for %s" % (self.name,)) - raise - - def update(self, sequence, *args): - try: - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - except: - logger.exception("Cache.update failed for %s" % (self.name,)) - raise - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] + else: + return default - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + def update(self, sequence, keyargs, value): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(keyargs, value) + def prefill(self, keyargs, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) self.cache[keyargs] = value - def invalidate(self, *keyargs): + def invalidate(self, keyargs): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(keyargs, tuple): + raise ValueError("keyargs must be a tuple.") + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 @@ -185,20 +169,21 @@ class CacheDescriptor(object): % (orig.__name__,) ) - def __get__(self, obj, objtype=None): - cache = Cache( + self.cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, ) + def __get__(self, obj, objtype=None): + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = cache.get(*keyargs) + cached_result_d = self.cache.get(keyargs) if DEBUG_CACHES: @@ -219,7 +204,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence ret = defer.maybeDeferred( self.function_to_call, @@ -227,19 +212,19 @@ class CacheDescriptor(object): ) def onErr(f): - cache.invalidate(*keyargs) + self.cache.invalidate(keyargs) return f ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=False) - cache.update(sequence, *(keyargs + [ret])) + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, keyargs, ret) return ret.observe() - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill obj.__dict__[self.orig.__name__] = wrapped diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 2b2bdf8615..f3947bbe89 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore): }, desc="create_room_alias_association", ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) @defer.inlineCallbacks def delete_room_alias(self, room_alias): @@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 45b86c94e8..910b6598a7 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore): for room_id in events_by_room: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, room_id + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) def get_backfill_events(self, room_id, event_list, limit): @@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore): query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ed7ea38804..5b64918024 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore): if current_state: txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) self._simple_delete_txn( @@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore): if not context.rejected: txn.call_after( self.get_current_state_for_key.invalidate, - event.room_id, event.type, event.state_key - ) + (event.room_id, event.type, event.state_key,) + ) if event.type in [EventTypes.Name, EventTypes.Aliases]: txn.call_after( self.get_room_name_and_aliases.invalidate, - event.room_id + (event.room_id,) ) self._simple_upsert_txn( @@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) + self._get_event_cache.invalidate( + (event_id, check_redacted, get_prev_content) + ) def _get_event_txn(self, txn, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False): @@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore): for event_id in events: try: ret = self._get_event_cache.get( - event_id, check_redacted, get_prev_content + (event_id, check_redacted, get_prev_content,) ) if allow_rejected or not ret.rejected_reason: @@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) defer.returnValue(ev) @@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) return ev diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e3f98f0cde..49b8e37cfd 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) - self.get_all_server_verify_keys.invalidate(server_name) + self.get_all_server_verify_keys.invalidate((server_name,)) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index fefcf6bce0..576cf670cc 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore): updatevalues={"accepted": True}, desc="set_presence_list_accepted", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore): "observed_user_id": observed_userid}, desc="del_presence_list", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index a220f3632e..9b88ca7b39 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -151,11 +151,11 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -187,10 +187,10 @@ class PushRuleStore(SQLBaseStore): new_rule['priority'] = new_prio txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -216,8 +216,8 @@ class PushRuleStore(SQLBaseStore): desc="delete_push_rule", ) - self.get_push_rules_for_user.invalidate(user_name) - self.get_push_rules_enabled_for_user.invalidate(user_name) + self.get_push_rules_for_user.invalidate((user_name,)) + self.get_push_rules_enabled_for_user.invalidate((user_name,)) @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): @@ -238,10 +238,10 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 90e2606be2..4eaa088b36 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate(r) + self.get_user_by_token.invalidate((r,)) @cached() def get_user_by_token(self, token): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 55dd3f6cfb..9f14f38f24 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) + txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8fa305d18a..abee2f631d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -42,12 +42,12 @@ class CacheTestCase(unittest.TestCase): self.assertEquals(self.cache.get("foo"), 123) def test_invalidate(self): - self.cache.prefill("foo", 123) - self.cache.invalidate("foo") + self.cache.prefill(("foo",), 123) + self.cache.invalidate(("foo",)) failed = False try: - self.cache.get("foo") + self.cache.get(("foo",)) except KeyError: failed = True @@ -141,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 1) - a.func.invalidate("foo") + a.func.invalidate(("foo",)) yield a.func("foo") @@ -153,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase): def func(self, key): return key - A().func.invalidate("what") + A().func.invalidate(("what",)) @defer.inlineCallbacks def test_max_entries(self): @@ -193,7 +193,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a = A() - a.func.prefill("foo", ObservableDeferred(d)) + a.func.prefill(("foo",), ObservableDeferred(d)) self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) -- cgit 1.4.1 From 02118901347ab5067fbb47169b36a8424e671bfb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 18:14:49 +0100 Subject: Implement a CacheListDescriptor --- synapse/storage/_base.py | 106 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 32089b05e5..556aa3b523 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,6 +16,7 @@ import logging from synapse.api.errors import StoreError from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -231,6 +232,101 @@ class CacheDescriptor(object): return wrapped +class CacheListDescriptor(object): + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.num_args = num_args + self.list_name = list_name + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.list_pos = self.arg_names.index(self.list_name) + + self.cache = cache + + self.sentinel = object() + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + if self.list_name not in self.arg_names: + raise Exception( + "Couldn't see arguments %r for %r." + % (self.list_name, cache.name,) + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + list_args = arg_dict[self.list_name] + + cached = {} + missing = [] + for arg in list_args: + key = list(keyargs) + key[self.list_pos] = arg + + try: + res = self.cache.get(tuple(key)).observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached[arg] = res + except KeyError: + missing.append(arg) + + if missing: + sequence = self.cache.sequence + args_to_call = dict(arg_dict) + args_to_call[self.list_name] = missing + + ret_d = defer.maybeDeferred( + self.function_to_call, + **args_to_call + ) + + ret_d = ObservableDeferred(ret_d) + + for arg in missing: + observer = ret_d.observe() + observer.addCallback(lambda r, arg: r[arg], arg) + + observer = ObservableDeferred(observer) + + key = list(keyargs) + key[self.list_pos] = arg + self.cache.update(sequence, tuple(key), observer) + + def invalidate(f, key): + self.cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) + + res = observer.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + + cached[arg] = res + + return defer.gatherResults( + cached.values(), + consumeErrors=True, + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + def cached(max_entries=1000, num_args=1, lru=True): return lambda orig: CacheDescriptor( orig, @@ -250,6 +346,16 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): ) +def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): + return lambda orig: CacheListDescriptor( + orig, + cache=cache, + list_name=list_name, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + ) + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() -- cgit 1.4.1 From 20addfa358e4f72b9f0c25d48c1e3ecfc08a68b2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 11:52:21 +0100 Subject: Change Cache to not use *args in its interface --- synapse/storage/__init__.py | 4 +-- synapse/storage/_base.py | 61 +++++++++++++++++++------------------ synapse/storage/directory.py | 4 +-- synapse/storage/event_federation.py | 4 +-- synapse/storage/events.py | 21 +++++++------ synapse/storage/keys.py | 2 +- synapse/storage/presence.py | 4 +-- synapse/storage/push_rule.py | 16 +++++----- synapse/storage/registration.py | 2 +- synapse/storage/roommember.py | 6 ++-- tests/storage/test__base.py | 12 ++++---- 11 files changed, 69 insertions(+), 67 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 71d5d92500..1a6a8a3762 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore, key = (user.to_string(), access_token, device_id, ip) try: - last_seen = self.client_ip_last_seen.get(*key) + last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None @@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: defer.returnValue(None) - self.client_ip_last_seen.prefill(*key + (now,)) + self.client_ip_last_seen.prefill(key, now) # It's safe not to lock here: a) no unique constraint, # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0872a438f1..e07cf3b58a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,6 +57,9 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True): @@ -83,41 +86,38 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - - if keyargs in self.cache: + def get(self, keyargs, default=_CacheSentinel): + val = self.cache.get(keyargs, _CacheSentinel) + if val is not _CacheSentinel: cache_counter.inc_hits(self.name) - return self.cache[keyargs] + return val cache_counter.inc_misses(self.name) - raise KeyError() - def update(self, sequence, *args): + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, keyargs, value): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] - - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + self.prefill(keyargs, value) + def prefill(self, keyargs, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) self.cache[keyargs] = value - def invalidate(self, *keyargs): + def invalidate(self, keyargs): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(keyargs, tuple): + raise ValueError("keyargs must be a tuple.") + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 @@ -168,20 +168,21 @@ class CacheDescriptor(object): % (orig.__name__,) ) - def __get__(self, obj, objtype=None): - cache = Cache( + self.cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, ) + def __get__(self, obj, objtype=None): + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = cache.get(*keyargs) + cached_result_d = self.cache.get(keyargs) if DEBUG_CACHES: @@ -202,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence ret = defer.maybeDeferred( self.function_to_call, @@ -210,19 +211,19 @@ class CacheDescriptor(object): ) def onErr(f): - cache.invalidate(*keyargs) + self.cache.invalidate(keyargs) return f ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=False) - cache.update(sequence, *(keyargs + [ret])) + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, keyargs, ret) return ret.observe() - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill obj.__dict__[self.orig.__name__] = wrapped diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 2b2bdf8615..f3947bbe89 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore): }, desc="create_room_alias_association", ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) @defer.inlineCallbacks def delete_room_alias(self, room_alias): @@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 45b86c94e8..910b6598a7 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore): for room_id in events_by_room: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, room_id + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) def get_backfill_events(self, room_id, event_list, limit): @@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore): query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ed7ea38804..5b64918024 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore): if current_state: txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) self._simple_delete_txn( @@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore): if not context.rejected: txn.call_after( self.get_current_state_for_key.invalidate, - event.room_id, event.type, event.state_key - ) + (event.room_id, event.type, event.state_key,) + ) if event.type in [EventTypes.Name, EventTypes.Aliases]: txn.call_after( self.get_room_name_and_aliases.invalidate, - event.room_id + (event.room_id,) ) self._simple_upsert_txn( @@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) + self._get_event_cache.invalidate( + (event_id, check_redacted, get_prev_content) + ) def _get_event_txn(self, txn, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False): @@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore): for event_id in events: try: ret = self._get_event_cache.get( - event_id, check_redacted, get_prev_content + (event_id, check_redacted, get_prev_content,) ) if allow_rejected or not ret.rejected_reason: @@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) defer.returnValue(ev) @@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) return ev diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e3f98f0cde..49b8e37cfd 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) - self.get_all_server_verify_keys.invalidate(server_name) + self.get_all_server_verify_keys.invalidate((server_name,)) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index fefcf6bce0..576cf670cc 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore): updatevalues={"accepted": True}, desc="set_presence_list_accepted", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore): "observed_user_id": observed_userid}, desc="del_presence_list", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index a220f3632e..9b88ca7b39 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -151,11 +151,11 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -187,10 +187,10 @@ class PushRuleStore(SQLBaseStore): new_rule['priority'] = new_prio txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -216,8 +216,8 @@ class PushRuleStore(SQLBaseStore): desc="delete_push_rule", ) - self.get_push_rules_for_user.invalidate(user_name) - self.get_push_rules_enabled_for_user.invalidate(user_name) + self.get_push_rules_for_user.invalidate((user_name,)) + self.get_push_rules_enabled_for_user.invalidate((user_name,)) @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): @@ -238,10 +238,10 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 90e2606be2..4eaa088b36 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate(r) + self.get_user_by_token.invalidate((r,)) @cached() def get_user_by_token(self, token): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 55dd3f6cfb..9f14f38f24 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) + txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8fa305d18a..abee2f631d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -42,12 +42,12 @@ class CacheTestCase(unittest.TestCase): self.assertEquals(self.cache.get("foo"), 123) def test_invalidate(self): - self.cache.prefill("foo", 123) - self.cache.invalidate("foo") + self.cache.prefill(("foo",), 123) + self.cache.invalidate(("foo",)) failed = False try: - self.cache.get("foo") + self.cache.get(("foo",)) except KeyError: failed = True @@ -141,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 1) - a.func.invalidate("foo") + a.func.invalidate(("foo",)) yield a.func("foo") @@ -153,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase): def func(self, key): return key - A().func.invalidate("what") + A().func.invalidate(("what",)) @defer.inlineCallbacks def test_max_entries(self): @@ -193,7 +193,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a = A() - a.func.prefill("foo", ObservableDeferred(d)) + a.func.prefill(("foo",), ObservableDeferred(d)) self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) -- cgit 1.4.1 From 62126c996c5bc8d86d3e15c015a68e8e60622a55 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 19:17:58 +0100 Subject: Propogate stale cache errors to calling functions --- synapse/storage/_base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0872a438f1..524a003153 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -183,8 +183,8 @@ class CacheDescriptor(object): try: cached_result_d = cache.get(*keyargs) + observed = cached_result_d.observe() if DEBUG_CACHES: - @defer.inlineCallbacks def check_result(cached_result): actual_result = yield self.function_to_call(obj, *args, **kwargs) @@ -195,9 +195,10 @@ class CacheDescriptor(object): cached_result, actual_result, ) raise ValueError("Stale cache entry") - cached_result_d.observe().addCallback(check_result) + defer.returnValue(cached_result) + observed.addCallback(check_result) - return cached_result_d.observe() + return observed except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated -- cgit 1.4.1 From 9c5385b53ad6d8ac21227dbc3f150c7d2511b061 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 19:26:38 +0100 Subject: s/observed/observer/ --- synapse/storage/_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 524a003153..5997603b3c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -183,7 +183,7 @@ class CacheDescriptor(object): try: cached_result_d = cache.get(*keyargs) - observed = cached_result_d.observe() + observer = cached_result_d.observe() if DEBUG_CACHES: @defer.inlineCallbacks def check_result(cached_result): @@ -196,9 +196,9 @@ class CacheDescriptor(object): ) raise ValueError("Stale cache entry") defer.returnValue(cached_result) - observed.addCallback(check_result) + observer.addCallback(check_result) - return observed + return observer except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated -- cgit 1.4.1 From 2cd6cb9f65e64ea49e0a8ba7c4851c79821243ad Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 10:38:47 +0100 Subject: Rename keyargs to args in Cache --- synapse/storage/_base.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index a088a6175c..1e2d436660 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -86,8 +86,8 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, keyargs, default=_CacheSentinel): - val = self.cache.get(keyargs, _CacheSentinel) + def get(self, key, default=_CacheSentinel): + val = self.cache.get(key, _CacheSentinel) if val is not _CacheSentinel: cache_counter.inc_hits(self.name) return val @@ -99,29 +99,29 @@ class Cache(object): else: return default - def update(self, sequence, keyargs, value): + def update(self, sequence, key, value): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(keyargs, value) + self.prefill(key, value) - def prefill(self, keyargs, value): + def prefill(self, key, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) - self.cache[keyargs] = value + self.cache[key] = value - def invalidate(self, keyargs): + def invalidate(self, key): self.check_thread() - if not isinstance(keyargs, tuple): + if not isinstance(key, tuple): raise ValueError("keyargs must be a tuple.") # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 - self.cache.pop(keyargs, None) + self.cache.pop(key, None) def invalidate_all(self): self.check_thread() -- cgit 1.4.1 From 86eaaa885bb4f3d75990874a76c53b2a0870d639 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 13:44:44 +0100 Subject: Rename keyargs to args in CacheDescriptor --- synapse/storage/_base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1e2d436660..e76ee2779a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -180,9 +180,9 @@ class CacheDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(keyargs) + cached_result_d = self.cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -192,7 +192,7 @@ class CacheDescriptor(object): if actual_result != cached_result: logger.error( "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, keyargs, + self.orig.__name__, cache_key, cached_result, actual_result, ) raise ValueError("Stale cache entry") @@ -212,13 +212,13 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(keyargs) + self.cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, keyargs, ret) + self.cache.update(sequence, cache_key, ret) return ret.observe() -- cgit 1.4.1 From bb0a475c30e63a6d9d749a1f8587582904d3ff92 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 10 Aug 2015 14:27:38 +0100 Subject: Comments --- synapse/storage/_base.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 27713e8b74..826c393cd4 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -234,6 +234,12 @@ class CacheDescriptor(object): class CacheListDescriptor(object): + """Wraps an existing cache to support bulk fetching of keys. + + Given a list of keys it looks in the cache to find any hits, then passes + the list of missing keys to the wrapped fucntion. + """ + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): self.orig = orig -- cgit 1.4.1 From 4762c276cb460044099a98a4343c3f2b0ce2abe4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 11:33:41 +0100 Subject: Docs --- synapse/storage/_base.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 826c393cd4..bdca61d2ee 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -241,6 +241,15 @@ class CacheListDescriptor(object): """ def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + """ + Args: + orig (function) + cache (Cache) + list_name (str): Name of the argument which is the bulk lookup list + num_args (int) + inlineCallbacks (bool): Whether orig is a generator that should + be wrapped by defer.inlineCallbacks + """ self.orig = orig if inlineCallbacks: -- cgit 1.4.1 From 6eaa116867ebfd22718505d1aa29dce4679c87be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 11:35:24 +0100 Subject: Comment --- synapse/storage/_base.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bdca61d2ee..4d86fe7c72 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -288,6 +288,8 @@ class CacheListDescriptor(object): keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] + # cached is a dict arg -> deferred, where deferred results in a + # 2-tuple (`arg`, `result`) cached = {} missing = [] for arg in list_args: -- cgit 1.4.1 From 53a817518bd27e6838c747a03d605165f96dbd12 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 11:40:40 +0100 Subject: Comments --- synapse/storage/_base.py | 2 ++ synapse/storage/state.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 4d86fe7c72..e5441aafb2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -315,6 +315,8 @@ class CacheListDescriptor(object): ret_d = ObservableDeferred(ret_d) + # We need to create deferreds for each arg in the list so that + # we can insert the new deferred into the cache. for arg in missing: observer = ret_d.observe() observer.addCallback(lambda r, arg: r[arg], arg) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a438530071..ea5fa9de7b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -173,6 +173,9 @@ class StateStore(SQLBaseStore): def _get_state_groups_from_groups(self, groups_and_types): """Returns dictionary state_group -> state event ids + + Args: + groups_and_types (list): list of 2-tuple (`group`, `types`) """ def f(txn): results = {} @@ -284,8 +287,11 @@ class StateStore(SQLBaseStore): def _get_state_for_group_from_cache(self, group, types=None): """Checks if group is in cache. See `_get_state_for_groups` - Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the - list of types that aren't in the cache for that group. + Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). + `missing_types` is the list of types that aren't in the cache for that + group, or None if `types` is None. `got_all` is a bool indicating if + we successfully retrieved all requests state from the cache, if False + we need to query the DB for the missing state. """ is_all, state_dict = self._state_group_cache.get(group) @@ -323,11 +329,13 @@ class StateStore(SQLBaseStore): return True return False + got_all = not (missing_types or types is None) + return { k: v for k, v in state_dict.items() if include(k[0], k[1]) - }, missing_types, not missing_types and types is not None + }, missing_types, got_all @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): -- cgit 1.4.1 From 2df8dd9b37f26e3ad0d3647a1e78804a85d48c0c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 17:59:32 +0100 Subject: Move all the caches into their own package, synapse.util.caches --- synapse/federation/federation_client.py | 2 +- synapse/state.py | 2 +- synapse/storage/_base.py | 335 +---------------------------- synapse/storage/directory.py | 3 +- synapse/storage/event_federation.py | 3 +- synapse/storage/keys.py | 3 +- synapse/storage/presence.py | 3 +- synapse/storage/push_rule.py | 3 +- synapse/storage/receipts.py | 3 +- synapse/storage/registration.py | 3 +- synapse/storage/room.py | 3 +- synapse/storage/roommember.py | 3 +- synapse/storage/state.py | 5 +- synapse/storage/stream.py | 3 +- synapse/storage/transactions.py | 3 +- synapse/util/caches/__init__.py | 14 ++ synapse/util/caches/descriptors.py | 359 ++++++++++++++++++++++++++++++++ synapse/util/caches/dictionary_cache.py | 109 ++++++++++ synapse/util/caches/expiringcache.py | 115 ++++++++++ synapse/util/caches/lrucache.py | 149 +++++++++++++ synapse/util/dictionary_cache.py | 109 ---------- synapse/util/expiringcache.py | 115 ---------- synapse/util/lrucache.py | 149 ------------- tests/storage/test__base.py | 2 +- tests/util/test_dict_cache.py | 2 +- tests/util/test_lrucache.py | 4 +- 26 files changed, 780 insertions(+), 724 deletions(-) create mode 100644 synapse/util/caches/__init__.py create mode 100644 synapse/util/caches/descriptors.py create mode 100644 synapse/util/caches/dictionary_cache.py create mode 100644 synapse/util/caches/expiringcache.py create mode 100644 synapse/util/caches/lrucache.py delete mode 100644 synapse/util/dictionary_cache.py delete mode 100644 synapse/util/expiringcache.py delete mode 100644 synapse/util/lrucache.py (limited to 'synapse/storage/_base.py') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7736d14fb5..58a6d6a0ed 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -23,7 +23,7 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent import synapse.metrics diff --git a/synapse/state.py b/synapse/state.py index b5e5d7bbda..1fe4d066bd 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.auth import AuthEventTypes diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e5441aafb2..1444767a52 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,27 +15,22 @@ import logging from synapse.api.errors import StoreError -from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext -from synapse.util.lrucache import LruCache -from synapse.util.dictionary_cache import DictionaryCache +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.caches.descriptors import Cache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple, OrderedDict +from collections import namedtuple -import functools -import inspect import sys import time import threading -DEBUG_CACHES = False logger = logging.getLogger(__name__) @@ -51,330 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) -caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) - - -_CacheSentinel = object() - - -class Cache(object): - - def __init__(self, name, max_entries=1000, keylen=1, lru=True): - if lru: - self.cache = LruCache(max_size=max_entries) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries - - self.name = name - self.keylen = keylen - self.sequence = 0 - self.thread = None - caches_by_name[name] = self.cache - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) - return val - - cache_counter.inc_misses(self.name) - - if default is _CacheSentinel: - raise KeyError() - else: - return default - - def update(self, sequence, key, value): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) - - def prefill(self, key, value): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) - - self.cache[key] = value - - def invalidate(self, key): - self.check_thread() - if not isinstance(key, tuple): - raise ValueError("keyargs must be a tuple.") - - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 - self.cache.pop(key, None) - - def invalidate_all(self): - self.check_thread() - self.sequence += 1 - self.cache.clear() - - -class CacheDescriptor(object): - """ A method decorator that applies a memoizing cache around the function. - - This caches deferreds, rather than the results themselves. Deferreds that - fail are removed from the cache. - - The function is presumed to take zero or more arguments, which are used in - a tuple as the key for the cache. Hits are served directly from the cache; - misses use the function body to generate the value. - - The wrapped function has an additional member, a callable called - "invalidate". This can be used to remove individual entries from the cache. - - The wrapped function has another additional callable, called "prefill", - which can be used to insert values into the cache specifically, without - calling the calculation function. - """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, - inlineCallbacks=False): - self.orig = orig - - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.max_entries = max_entries - self.num_args = num_args - self.lru = lru - - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - - self.cache = Cache( - name=self.orig.__name__, - max_entries=self.max_entries, - keylen=self.num_args, - lru=self.lru, - ) - - def __get__(self, obj, objtype=None): - - @functools.wraps(self.orig) - def wrapped(*args, **kwargs): - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) - try: - cached_result_d = self.cache.get(cache_key) - - observer = cached_result_d.observe() - if DEBUG_CACHES: - @defer.inlineCallbacks - def check_result(cached_result): - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, cache_key, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) - observer.addCallback(check_result) - - return observer - except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence - - ret = defer.maybeDeferred( - self.function_to_call, - obj, *args, **kwargs - ) - - def onErr(f): - self.cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) - - return ret.observe() - - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.prefill = self.cache.prefill - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -class CacheListDescriptor(object): - """Wraps an existing cache to support bulk fetching of keys. - - Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped fucntion. - """ - - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): - """ - Args: - orig (function) - cache (Cache) - list_name (str): Name of the argument which is the bulk lookup list - num_args (int) - inlineCallbacks (bool): Whether orig is a generator that should - be wrapped by defer.inlineCallbacks - """ - self.orig = orig - - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.num_args = num_args - self.list_name = list_name - - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] - self.list_pos = self.arg_names.index(self.list_name) - - self.cache = cache - - self.sentinel = object() - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - - if self.list_name not in self.arg_names: - raise Exception( - "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) - ) - - def __get__(self, obj, objtype=None): - - @functools.wraps(self.orig) - def wrapped(*args, **kwargs): - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] - list_args = arg_dict[self.list_name] - - # cached is a dict arg -> deferred, where deferred results in a - # 2-tuple (`arg`, `result`) - cached = {} - missing = [] - for arg in list_args: - key = list(keyargs) - key[self.list_pos] = arg - - try: - res = self.cache.get(tuple(key)).observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res - except KeyError: - missing.append(arg) - - if missing: - sequence = self.cache.sequence - args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing - - ret_d = defer.maybeDeferred( - self.function_to_call, - **args_to_call - ) - - ret_d = ObservableDeferred(ret_d) - - # We need to create deferreds for each arg in the list so that - # we can insert the new deferred into the cache. - for arg in missing: - observer = ret_d.observe() - observer.addCallback(lambda r, arg: r[arg], arg) - - observer = ObservableDeferred(observer) - - key = list(keyargs) - key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) - - def invalidate(f, key): - self.cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) - - res = observer.observe() - res.addCallback(lambda r, arg: (arg, r), arg) - - cached[arg] = res - - return defer.gatherResults( - cached.values(), - consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -def cached(max_entries=1000, num_args=1, lru=True): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru - ) - - -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru, - inlineCallbacks=True, - ) - - -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): - return lambda orig: CacheListDescriptor( - orig, - cache=cache, - list_name=list_name, - num_args=num_args, - inlineCallbacks=inlineCallbacks, - ) - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index f3947bbe89..d92028ea43 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.errors import SynapseError diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 910b6598a7..25cc84eb95 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -15,7 +15,8 @@ from twisted.internet import defer -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from syutil.base64util import encode_base64 import logging diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 49b8e37cfd..ffd6daa880 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, cachedInlineCallbacks +from _base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 576cf670cc..4f91a2b87c 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from twisted.internet import defer diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9b88ca7b39..5305b7e122 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer import logging diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index b79d6683ca..cac1a5657e 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4eaa088b36..aa446f94c6 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached class RegistrationStore(SQLBaseStore): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index dd5bc2c8fb..5e07b7e0e5 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks import collections import logging diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9f14f38f24..8eee2dfbcc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,7 +17,8 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.constants import Membership from synapse.types import UserID diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ea5fa9de7b..79c3b82d9f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import ( + cached, cachedInlineCallbacks, cachedList +) from twisted.internet import defer diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b59fe81004..d7fe423f5a 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,7 +35,8 @@ what sort order was used: from twisted.internet import defer -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 624da4a9dc..c8c7e6591a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from collections import namedtuple diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py new file mode 100644 index 0000000000..1a84d94cd9 --- /dev/null +++ b/synapse/util/caches/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py new file mode 100644 index 0000000000..82dd09cf5e --- /dev/null +++ b/synapse/util/caches/descriptors.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError +from synapse.util.caches.lrucache import LruCache +import synapse.metrics + +from twisted.internet import defer + +from collections import OrderedDict + +import functools +import inspect +import threading + +logger = logging.getLogger(__name__) + + +DEBUG_CACHES = False + +metrics = synapse.metrics.get_metrics_for("synapse.util.caches") + +caches_by_name = {} +cache_counter = metrics.register_cache( + "cache", + lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, + labels=["name"], +) + + +_CacheSentinel = object() + + +class Cache(object): + + def __init__(self, name, max_entries=1000, keylen=1, lru=True): + if lru: + self.cache = LruCache(max_size=max_entries) + self.max_entries = None + else: + self.cache = OrderedDict() + self.max_entries = max_entries + + self.name = name + self.keylen = keylen + self.sequence = 0 + self.thread = None + caches_by_name[name] = self.cache + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, default=_CacheSentinel): + val = self.cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + cache_counter.inc_hits(self.name) + return val + + cache_counter.inc_misses(self.name) + + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, key, value): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(key, value) + + def prefill(self, key, value): + if self.max_entries is not None: + while len(self.cache) >= self.max_entries: + self.cache.popitem(last=False) + + self.cache[key] = value + + def invalidate(self, key): + self.check_thread() + if not isinstance(key, tuple): + raise ValueError("keyargs must be a tuple.") + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + +class CacheDescriptor(object): + """ A method decorator that applies a memoizing cache around the function. + + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + + The function is presumed to take zero or more arguments, which are used in + a tuple as the key for the cache. Hits are served directly from the cache; + misses use the function body to generate the value. + + The wrapped function has an additional member, a callable called + "invalidate". This can be used to remove individual entries from the cache. + + The wrapped function has another additional callable, called "prefill", + which can be used to insert values into the cache specifically, without + calling the calculation function. + """ + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + inlineCallbacks=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.max_entries = max_entries + self.num_args = num_args + self.lru = lru + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + self.cache = Cache( + name=self.orig.__name__, + max_entries=self.max_entries, + keylen=self.num_args, + lru=self.lru, + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + try: + cached_result_d = self.cache.get(cache_key) + + observer = cached_result_d.observe() + if DEBUG_CACHES: + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, cache_key, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) + observer.addCallback(check_result) + + return observer + except KeyError: + # Get the sequence number of the cache before reading from the + # database so that we can tell if the cache is invalidated + # while the SELECT is executing (SYN-369) + sequence = self.cache.sequence + + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + + def onErr(f): + self.cache.invalidate(cache_key) + return f + + ret.addErrback(onErr) + + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, cache_key, ret) + + return ret.observe() + + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class CacheListDescriptor(object): + """Wraps an existing cache to support bulk fetching of keys. + + Given a list of keys it looks in the cache to find any hits, then passes + the list of missing keys to the wrapped fucntion. + """ + + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + """ + Args: + orig (function) + cache (Cache) + list_name (str): Name of the argument which is the bulk lookup list + num_args (int) + inlineCallbacks (bool): Whether orig is a generator that should + be wrapped by defer.inlineCallbacks + """ + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.num_args = num_args + self.list_name = list_name + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.list_pos = self.arg_names.index(self.list_name) + + self.cache = cache + + self.sentinel = object() + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + if self.list_name not in self.arg_names: + raise Exception( + "Couldn't see arguments %r for %r." + % (self.list_name, cache.name,) + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + list_args = arg_dict[self.list_name] + + # cached is a dict arg -> deferred, where deferred results in a + # 2-tuple (`arg`, `result`) + cached = {} + missing = [] + for arg in list_args: + key = list(keyargs) + key[self.list_pos] = arg + + try: + res = self.cache.get(tuple(key)).observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached[arg] = res + except KeyError: + missing.append(arg) + + if missing: + sequence = self.cache.sequence + args_to_call = dict(arg_dict) + args_to_call[self.list_name] = missing + + ret_d = defer.maybeDeferred( + self.function_to_call, + **args_to_call + ) + + ret_d = ObservableDeferred(ret_d) + + # We need to create deferreds for each arg in the list so that + # we can insert the new deferred into the cache. + for arg in missing: + observer = ret_d.observe() + observer.addCallback(lambda r, arg: r[arg], arg) + + observer = ObservableDeferred(observer) + + key = list(keyargs) + key[self.list_pos] = arg + self.cache.update(sequence, tuple(key), observer) + + def invalidate(f, key): + self.cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) + + res = observer.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + + cached[arg] = res + + return defer.gatherResults( + cached.values(), + consumeErrors=True, + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +def cached(max_entries=1000, num_args=1, lru=True): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru + ) + + +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) + + +def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): + return lambda orig: CacheListDescriptor( + orig, + cache=cache, + list_name=list_name, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py new file mode 100644 index 0000000000..26d464f4f7 --- /dev/null +++ b/synapse/util/caches/dictionary_cache.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.caches.lrucache import LruCache +from collections import namedtuple +import threading +import logging + + +logger = logging.getLogger(__name__) + + +DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) + + +class DictionaryCache(object): + """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. + fetching a subset of dictionary keys for a particular key. + """ + + def __init__(self, name, max_entries=1000): + self.cache = LruCache(max_size=max_entries) + + self.name = name + self.sequence = 0 + self.thread = None + # caches_by_name[name] = self.cache + + class Sentinel(object): + __slots__ = [] + + self.sentinel = Sentinel() + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, dict_keys=None): + try: + entry = self.cache.get(key, self.sentinel) + if entry is not self.sentinel: + # cache_counter.inc_hits(self.name) + + if dict_keys is None: + return DictionaryEntry(entry.full, dict(entry.value)) + else: + return DictionaryEntry(entry.full, { + k: entry.value[k] + for k in dict_keys + if k in entry.value + }) + + # cache_counter.inc_misses(self.name) + return DictionaryEntry(False, {}) + except: + logger.exception("get failed") + raise + + def invalidate(self, key): + self.check_thread() + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + def update(self, sequence, key, value, full=False): + try: + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + if full: + self._insert(key, value) + else: + self._update_or_insert(key, value) + except: + logger.exception("update failed") + raise + + def _update_or_insert(self, key, value): + entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + entry.value.update(value) + + def _insert(self, key, value): + self.cache[key] = DictionaryEntry(True, value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py new file mode 100644 index 0000000000..06d1eea01b --- /dev/null +++ b/synapse/util/caches/expiringcache.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +logger = logging.getLogger(__name__) + + +class ExpiringCache(object): + def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, + reset_expiry_on_get=False): + """ + Args: + cache_name (str): Name of this cache, used for logging. + clock (Clock) + max_len (int): Max size of dict. If the dict grows larger than this + then the oldest items get automatically evicted. Default is 0, + which indicates there is no max limit. + expiry_ms (int): How long before an item is evicted from the cache + in milliseconds. Default is 0, indicating items never get + evicted based on time. + reset_expiry_on_get (bool): If true, will reset the expiry time for + an item on access. Defaults to False. + + """ + self._cache_name = cache_name + + self._clock = clock + + self._max_len = max_len + self._expiry_ms = expiry_ms + + self._reset_expiry_on_get = reset_expiry_on_get + + self._cache = {} + + def start(self): + if not self._expiry_ms: + # Don't bother starting the loop if things never expire + return + + def f(): + self._prune_cache() + + self._clock.looping_call(f, self._expiry_ms/2) + + def __setitem__(self, key, value): + now = self._clock.time_msec() + self._cache[key] = _CacheEntry(now, value) + + # Evict if there are now too many items + if self._max_len and len(self._cache.keys()) > self._max_len: + sorted_entries = sorted( + self._cache.items(), + key=lambda (k, v): v.time, + ) + + for k, _ in sorted_entries[self._max_len:]: + self._cache.pop(k) + + def __getitem__(self, key): + entry = self._cache[key] + + if self._reset_expiry_on_get: + entry.time = self._clock.time_msec() + + return entry.value + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def _prune_cache(self): + if not self._expiry_ms: + # zero expiry time means don't expire. This should never get called + # since we have this check in start too. + return + begin_length = len(self._cache) + + now = self._clock.time_msec() + + keys_to_delete = set() + + for key, cache_entry in self._cache.items(): + if now - cache_entry.time > self._expiry_ms: + keys_to_delete.add(key) + + for k in keys_to_delete: + self._cache.pop(k) + + logger.debug( + "[%s] _prune_cache before: %d, after len: %d", + self._cache_name, begin_length, len(self._cache.keys()) + ) + + +class _CacheEntry(object): + def __init__(self, time, value): + self.time = time + self.value = value diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py new file mode 100644 index 0000000000..cacd7e45fa --- /dev/null +++ b/synapse/util/caches/lrucache.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import wraps +import threading + + +class LruCache(object): + """Least-recently-used cache.""" + def __init__(self, max_size): + cache = {} + list_root = [] + list_root[:] = [list_root, list_root, None, None] + + PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 + + lock = threading.Lock() + + def synchronized(f): + @wraps(f) + def inner(*args, **kwargs): + with lock: + return f(*args, **kwargs) + + return inner + + def add_node(key, value): + prev_node = list_root + next_node = prev_node[NEXT] + node = [prev_node, next_node, key, value] + prev_node[NEXT] = node + next_node[PREV] = node + cache[key] = node + + def move_node_to_front(node): + prev_node = node[PREV] + next_node = node[NEXT] + prev_node[NEXT] = next_node + next_node[PREV] = prev_node + prev_node = list_root + next_node = prev_node[NEXT] + node[PREV] = prev_node + node[NEXT] = next_node + prev_node[NEXT] = node + next_node[PREV] = node + + def delete_node(node): + prev_node = node[PREV] + next_node = node[NEXT] + prev_node[NEXT] = next_node + next_node[PREV] = prev_node + cache.pop(node[KEY], None) + + @synchronized + def cache_get(key, default=None): + node = cache.get(key, None) + if node is not None: + move_node_to_front(node) + return node[VALUE] + else: + return default + + @synchronized + def cache_set(key, value): + node = cache.get(key, None) + if node is not None: + move_node_to_front(node) + node[VALUE] = value + else: + add_node(key, value) + if len(cache) > max_size: + delete_node(list_root[PREV]) + + @synchronized + def cache_set_default(key, value): + node = cache.get(key, None) + if node is not None: + return node[VALUE] + else: + add_node(key, value) + if len(cache) > max_size: + delete_node(list_root[PREV]) + return value + + @synchronized + def cache_pop(key, default=None): + node = cache.get(key, None) + if node: + delete_node(node) + return node[VALUE] + else: + return default + + @synchronized + def cache_clear(): + list_root[NEXT] = list_root + list_root[PREV] = list_root + cache.clear() + + @synchronized + def cache_len(): + return len(cache) + + @synchronized + def cache_contains(key): + return key in cache + + self.sentinel = object() + self.get = cache_get + self.set = cache_set + self.setdefault = cache_set_default + self.pop = cache_pop + self.len = cache_len + self.contains = cache_contains + self.clear = cache_clear + + def __getitem__(self, key): + result = self.get(key, self.sentinel) + if result is self.sentinel: + raise KeyError() + else: + return result + + def __setitem__(self, key, value): + self.set(key, value) + + def __delitem__(self, key, value): + result = self.pop(key, self.sentinel) + if result is self.sentinel: + raise KeyError() + + def __len__(self): + return self.len() + + def __contains__(self, key): + return self.contains(key) diff --git a/synapse/util/dictionary_cache.py b/synapse/util/dictionary_cache.py deleted file mode 100644 index c7564cdf0d..0000000000 --- a/synapse/util/dictionary_cache.py +++ /dev/null @@ -1,109 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.lrucache import LruCache -from collections import namedtuple -import threading -import logging - - -logger = logging.getLogger(__name__) - - -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) - - -class DictionaryCache(object): - """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. - fetching a subset of dictionary keys for a particular key. - """ - - def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) - - self.name = name - self.sequence = 0 - self.thread = None - # caches_by_name[name] = self.cache - - class Sentinel(object): - __slots__ = [] - - self.sentinel = Sentinel() - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, key, dict_keys=None): - try: - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: - # cache_counter.inc_hits(self.name) - - if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) - else: - return DictionaryEntry(entry.full, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) - - # cache_counter.inc_misses(self.name) - return DictionaryEntry(False, {}) - except: - logger.exception("get failed") - raise - - def invalidate(self, key): - self.check_thread() - - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 - self.cache.pop(key, None) - - def invalidate_all(self): - self.check_thread() - self.sequence += 1 - self.cache.clear() - - def update(self, sequence, key, value, full=False): - try: - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - if full: - self._insert(key, value) - else: - self._update_or_insert(key, value) - except: - logger.exception("update failed") - raise - - def _update_or_insert(self, key, value): - entry = self.cache.setdefault(key, DictionaryEntry(False, {})) - entry.value.update(value) - - def _insert(self, key, value): - self.cache[key] = DictionaryEntry(True, value) diff --git a/synapse/util/expiringcache.py b/synapse/util/expiringcache.py deleted file mode 100644 index 06d1eea01b..0000000000 --- a/synapse/util/expiringcache.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - - -logger = logging.getLogger(__name__) - - -class ExpiringCache(object): - def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): - """ - Args: - cache_name (str): Name of this cache, used for logging. - clock (Clock) - max_len (int): Max size of dict. If the dict grows larger than this - then the oldest items get automatically evicted. Default is 0, - which indicates there is no max limit. - expiry_ms (int): How long before an item is evicted from the cache - in milliseconds. Default is 0, indicating items never get - evicted based on time. - reset_expiry_on_get (bool): If true, will reset the expiry time for - an item on access. Defaults to False. - - """ - self._cache_name = cache_name - - self._clock = clock - - self._max_len = max_len - self._expiry_ms = expiry_ms - - self._reset_expiry_on_get = reset_expiry_on_get - - self._cache = {} - - def start(self): - if not self._expiry_ms: - # Don't bother starting the loop if things never expire - return - - def f(): - self._prune_cache() - - self._clock.looping_call(f, self._expiry_ms/2) - - def __setitem__(self, key, value): - now = self._clock.time_msec() - self._cache[key] = _CacheEntry(now, value) - - # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: - sorted_entries = sorted( - self._cache.items(), - key=lambda (k, v): v.time, - ) - - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) - - def __getitem__(self, key): - entry = self._cache[key] - - if self._reset_expiry_on_get: - entry.time = self._clock.time_msec() - - return entry.value - - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def _prune_cache(self): - if not self._expiry_ms: - # zero expiry time means don't expire. This should never get called - # since we have this check in start too. - return - begin_length = len(self._cache) - - now = self._clock.time_msec() - - keys_to_delete = set() - - for key, cache_entry in self._cache.items(): - if now - cache_entry.time > self._expiry_ms: - keys_to_delete.add(key) - - for k in keys_to_delete: - self._cache.pop(k) - - logger.debug( - "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache.keys()) - ) - - -class _CacheEntry(object): - def __init__(self, time, value): - self.time = time - self.value = value diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py deleted file mode 100644 index cacd7e45fa..0000000000 --- a/synapse/util/lrucache.py +++ /dev/null @@ -1,149 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from functools import wraps -import threading - - -class LruCache(object): - """Least-recently-used cache.""" - def __init__(self, max_size): - cache = {} - list_root = [] - list_root[:] = [list_root, list_root, None, None] - - PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 - - lock = threading.Lock() - - def synchronized(f): - @wraps(f) - def inner(*args, **kwargs): - with lock: - return f(*args, **kwargs) - - return inner - - def add_node(key, value): - prev_node = list_root - next_node = prev_node[NEXT] - node = [prev_node, next_node, key, value] - prev_node[NEXT] = node - next_node[PREV] = node - cache[key] = node - - def move_node_to_front(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node - prev_node = list_root - next_node = prev_node[NEXT] - node[PREV] = prev_node - node[NEXT] = next_node - prev_node[NEXT] = node - next_node[PREV] = node - - def delete_node(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node - cache.pop(node[KEY], None) - - @synchronized - def cache_get(key, default=None): - node = cache.get(key, None) - if node is not None: - move_node_to_front(node) - return node[VALUE] - else: - return default - - @synchronized - def cache_set(key, value): - node = cache.get(key, None) - if node is not None: - move_node_to_front(node) - node[VALUE] = value - else: - add_node(key, value) - if len(cache) > max_size: - delete_node(list_root[PREV]) - - @synchronized - def cache_set_default(key, value): - node = cache.get(key, None) - if node is not None: - return node[VALUE] - else: - add_node(key, value) - if len(cache) > max_size: - delete_node(list_root[PREV]) - return value - - @synchronized - def cache_pop(key, default=None): - node = cache.get(key, None) - if node: - delete_node(node) - return node[VALUE] - else: - return default - - @synchronized - def cache_clear(): - list_root[NEXT] = list_root - list_root[PREV] = list_root - cache.clear() - - @synchronized - def cache_len(): - return len(cache) - - @synchronized - def cache_contains(key): - return key in cache - - self.sentinel = object() - self.get = cache_get - self.set = cache_set - self.setdefault = cache_set_default - self.pop = cache_pop - self.len = cache_len - self.contains = cache_contains - self.clear = cache_clear - - def __getitem__(self, key): - result = self.get(key, self.sentinel) - if result is self.sentinel: - raise KeyError() - else: - return result - - def __setitem__(self, key, value): - self.set(key, value) - - def __delitem__(self, key, value): - result = self.pop(key, self.sentinel) - if result is self.sentinel: - raise KeyError() - - def __len__(self): - return self.len() - - def __contains__(self, key): - return self.contains(key) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index abee2f631d..e72cace8ff 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.util.async import ObservableDeferred -from synapse.storage._base import Cache, cached +from synapse.util.caches.descriptors import Cache, cached class CacheTestCase(unittest.TestCase): diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 79bc1225d6..54ff26cd97 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -17,7 +17,7 @@ from twisted.internet import defer from tests import unittest -from synapse.util.dictionary_cache import DictionaryCache +from synapse.util.caches.dictionary_cache import DictionaryCache class DictCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index ab934bf928..fc5a904323 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -16,7 +16,7 @@ from .. import unittest -from synapse.util.lrucache import LruCache +from synapse.util.caches.lrucache import LruCache class LruCacheTestCase(unittest.TestCase): @@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase): cache["key"] = 1 self.assertEquals(cache.pop("key"), 1) self.assertEquals(cache.pop("key"), None) - - -- cgit 1.4.1 From cd800ad99afa5c689f10fd5e10843b74e4589336 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Sep 2015 09:54:51 +0100 Subject: Lower size of 'stateGroupCache' now that we have data from matrix.org to support doing so --- synapse/storage/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/_base.py') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1444767a52..d976e17786 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -167,7 +167,7 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) - self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000) + self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000) self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] -- cgit 1.4.1