summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state.py63
-rw-r--r--synapse/util/__init__.py10
2 files changed, 72 insertions, 1 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 695a5e7ac4..c45bab5859 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -43,14 +43,30 @@ AuthEventTypes = (
 )
 
 
+class _StateCacheEntry(object):
+    def __init__(self, state, state_group, ts):
+        self.state = state
+        self.state_group = state_group
+        self.ts = ts
+
+
 class StateHandler(object):
     """ Responsible for doing state conflict resolution.
     """
 
     def __init__(self, hs):
+        self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.hs = hs
 
+        # set of event_ids -> _StateCacheEntry.
+        self._state_cache = {}
+
+        def f():
+            self._prune_cache()
+
+        self.clock.looping_call(f, 10*1000)
+
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
         """ Returns the current state for the room as a list. This is done by
@@ -70,6 +86,11 @@ class StateHandler(object):
             for e_id, _, _ in events
         ]
 
+        cache = self._state_cache.get(set(event_ids), None)
+        if cache:
+            cache.ts = self.clock.time_msec()
+            defer.returnValue(cache.state_group, cache.state)
+
         res = yield self.resolve_state_groups(event_ids)
 
         if event_type:
@@ -177,6 +198,11 @@ class StateHandler(object):
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
+        cache = self._state_cache.get(set(event_ids), None)
+        if cache and cache.state_group:
+            cache.ts = self.clock.time_msec()
+            defer.returnValue(cache.state_group, cache.state)
+
         state_groups = yield self.store.get_state_groups(
             event_ids
         )
@@ -200,6 +226,14 @@ class StateHandler(object):
             else:
                 prev_states = []
 
+            cache = _StateCacheEntry(
+                state=state,
+                state_group=name,
+                ts=self.clock.time_msec()
+            )
+
+            self._state_cache[set(event_ids)] = cache
+
             defer.returnValue((name, state, prev_states))
 
         state = {}
@@ -245,6 +279,14 @@ class StateHandler(object):
         new_state = unconflicted_state
         new_state.update(resolved_state)
 
+        cache = _StateCacheEntry(
+            state=new_state,
+            state_group=None,
+            ts=self.clock.time_msec()
+        )
+
+        self._state_cache[set(event_ids)] = cache
+
         defer.returnValue((None, new_state, prev_states))
 
     @log_function
@@ -328,3 +370,24 @@ class StateHandler(object):
             return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
 
         return sorted(events, key=key_func)
+
+    def _prune_cache(self):
+        now = self.clock.time_msec()
+
+        if len(self._state_cache) > 100:
+            sorted_entries = sorted(
+                self._state_cache.items(),
+                key=lambda k, v: v.ts,
+            )
+
+            for k, _ in sorted_entries[100:]:
+                self._state_cache.pop(k)
+
+        keys_to_delete = set()
+
+        for key, cache_entry in self._state_cache.items():
+            if now - cache_entry.ts > 60*1000:
+                keys_to_delete.add(key)
+
+        for k in keys_to_delete:
+            self._state_cache.pop(k)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 4e837a918e..1fd5ba5787 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,7 +15,7 @@
 
 from synapse.util.logcontext import LoggingContext
 
-from twisted.internet import reactor
+from twisted.internet import reactor, task
 
 import time
 
@@ -35,6 +35,14 @@ class Clock(object):
         """Returns the current system time in miliseconds since epoch."""
         return self.time() * 1000
 
+    def looping_call(self, f, msec):
+        l = task.LoopingCall(f)
+        l.start(msec/1000.0, now=False)
+        return l
+
+    def looping_call(self, loop):
+        loop.stop()
+
     def call_later(self, delay, callback):
         current_context = LoggingContext.current_context()