diff options
-rwxr-xr-x | synapse/app/homeserver.py | 1 | ||||
-rw-r--r-- | synapse/state.py | 108 | ||||
-rw-r--r-- | tests/test_state.py | 7 |
3 files changed, 112 insertions, 4 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8976ff2e82..2b17cae54f 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -274,6 +274,7 @@ def setup(): hs.get_pusherpool().start() + hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() if config.daemonize: diff --git a/synapse/state.py b/synapse/state.py index 54380b9e5c..64c58a3934 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -43,14 +43,43 @@ AuthEventTypes = ( ) +SIZE_OF_CACHE = 1000 +EVICTION_TIMEOUT_SECONDS = 20 + + +class _StateCacheEntry(object): + def __init__(self, state, state_group, ts): + self.state = state + self.state_group = state_group + self.ts = ts + + class StateHandler(object): """ Responsible for doing state conflict resolution. """ def __init__(self, hs): + self.clock = hs.get_clock() self.store = hs.get_datastore() self.hs = hs + # dict of set of event_ids -> _StateCacheEntry. + self._state_cache = None + + def start_caching(self): + logger.debug("start_caching") + + self._state_cache = {} + + def f(): + logger.debug("Pruning") + try: + self._prune_cache() + except: + logger.exception("Prune") + + self.clock.looping_call(f, 5*1000) + @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): """ Returns the current state for the room as a list. This is done by @@ -70,13 +99,22 @@ class StateHandler(object): for e_id, _, _ in events ] - res = yield self.resolve_state_groups(event_ids) + cache = None + if self._state_cache is not None: + cache = self._state_cache.get(frozenset(event_ids), None) + + if cache: + cache.ts = self.clock.time_msec() + state = cache.state + else: + res = yield self.resolve_state_groups(event_ids) + state = res[1] if event_type: - defer.returnValue(res[1].get((event_type, state_key))) + defer.returnValue(state.get((event_type, state_key))) return - defer.returnValue(res[1]) + defer.returnValue(state) @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): @@ -177,6 +215,20 @@ class StateHandler(object): """ logger.debug("resolve_state_groups event_ids %s", event_ids) + if self._state_cache is not None: + cache = self._state_cache.get(frozenset(event_ids), None) + if cache and cache.state_group: + cache.ts = self.clock.time_msec() + prev_state = cache.state.get((event_type, state_key), None) + if prev_state: + prev_state = prev_state.event_id + prev_states = [prev_state] + else: + prev_states = [] + defer.returnValue( + (cache.state_group, cache.state, prev_states) + ) + state_groups = yield self.store.get_state_groups( event_ids ) @@ -200,6 +252,15 @@ class StateHandler(object): else: prev_states = [] + if self._state_cache is not None: + cache = _StateCacheEntry( + state=state, + state_group=name, + ts=self.clock.time_msec() + ) + + self._state_cache[frozenset(event_ids)] = cache + defer.returnValue((name, state, prev_states)) state = {} @@ -245,6 +306,15 @@ class StateHandler(object): new_state = unconflicted_state new_state.update(resolved_state) + if self._state_cache is not None: + cache = _StateCacheEntry( + state=new_state, + state_group=None, + ts=self.clock.time_msec() + ) + + self._state_cache[frozenset(event_ids)] = cache + defer.returnValue((None, new_state, prev_states)) @log_function @@ -328,3 +398,35 @@ class StateHandler(object): return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() return sorted(events, key=key_func) + + def _prune_cache(self): + logger.debug("_prune_cache") + logger.debug( + "_prune_cache. before len: %d", + len(self._state_cache.keys()) + ) + + now = self.clock.time_msec() + + if len(self._state_cache.keys()) > SIZE_OF_CACHE: + sorted_entries = sorted( + self._state_cache.items(), + key=lambda k, v: v.ts, + ) + + for k, _ in sorted_entries[SIZE_OF_CACHE:]: + self._state_cache.pop(k) + + keys_to_delete = set() + + for key, cache_entry in self._state_cache.items(): + if now - cache_entry.ts > EVICTION_TIMEOUT_SECONDS*1000: + keys_to_delete.add(key) + + for k in keys_to_delete: + self._state_cache.pop(k) + + logger.debug( + "_prune_cache. after len: %d", + len(self._state_cache.keys()) + ) diff --git a/tests/test_state.py b/tests/test_state.py index 019e794aa2..fea25f7021 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -21,6 +21,8 @@ from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership from synapse.state import StateHandler +from .utils import MockClock + from mock import Mock @@ -138,10 +140,13 @@ class StateTestCase(unittest.TestCase): "add_event_hashes", ] ) - hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"]) + hs = Mock(spec=[ + "get_datastore", "get_auth", "get_state_handler", "get_clock", + ]) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_auth.return_value = Auth(hs) + hs.get_clock.return_value = MockClock() self.state = StateHandler(hs) self.event_id = 0 |