summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/state.py31
-rw-r--r--synapse/util/__init__.py2
2 files changed, 22 insertions, 11 deletions
diff --git a/synapse/state.py b/synapse/state.py
index c45bab5859..7523573f22 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -86,18 +86,19 @@ class StateHandler(object):
             for e_id, _, _ in events
         ]
 
-        cache = self._state_cache.get(set(event_ids), None)
+        cache = self._state_cache.get(frozenset(event_ids), None)
         if cache:
             cache.ts = self.clock.time_msec()
-            defer.returnValue(cache.state_group, cache.state)
-
-        res = yield self.resolve_state_groups(event_ids)
+            state = cache.state
+        else:
+            res = yield self.resolve_state_groups(event_ids)
+            state = res[1]
 
         if event_type:
-            defer.returnValue(res[1].get((event_type, state_key)))
+            defer.returnValue(state.get((event_type, state_key)))
             return
 
-        defer.returnValue(res[1].values())
+        defer.returnValue(state.values())
 
     @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
@@ -198,10 +199,16 @@ class StateHandler(object):
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        cache = self._state_cache.get(set(event_ids), None)
+        cache = self._state_cache.get(frozenset(event_ids), None)
         if cache and cache.state_group:
             cache.ts = self.clock.time_msec()
-            defer.returnValue(cache.state_group, cache.state)
+            prev_state = cache.state.get((event_type, state_key), None)
+            if prev_state:
+                prev_state = prev_state.event_id
+                prev_states = [prev_state]
+            else:
+                prev_states = []
+            defer.returnValue((cache.state_group, cache.state, prev_states))
 
         state_groups = yield self.store.get_state_groups(
             event_ids
@@ -232,7 +239,7 @@ class StateHandler(object):
                 ts=self.clock.time_msec()
             )
 
-            self._state_cache[set(event_ids)] = cache
+            self._state_cache[frozenset(event_ids)] = cache
 
             defer.returnValue((name, state, prev_states))
 
@@ -285,7 +292,7 @@ class StateHandler(object):
             ts=self.clock.time_msec()
         )
 
-        self._state_cache[set(event_ids)] = cache
+        self._state_cache[frozenset(event_ids)] = cache
 
         defer.returnValue((None, new_state, prev_states))
 
@@ -372,6 +379,8 @@ class StateHandler(object):
         return sorted(events, key=key_func)
 
     def _prune_cache(self):
+        logger.debug("_prune_cache. before len: ", len(self._state_cache))
+
         now = self.clock.time_msec()
 
         if len(self._state_cache) > 100:
@@ -391,3 +400,5 @@ class StateHandler(object):
 
         for k in keys_to_delete:
             self._state_cache.pop(k)
+
+        logger.debug("_prune_cache. after len: ", len(self._state_cache))
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 1fd5ba5787..fee76b0a9b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -40,7 +40,7 @@ class Clock(object):
         l.start(msec/1000.0, now=False)
         return l
 
-    def looping_call(self, loop):
+    def stop_looping_call(self, loop):
         loop.stop()
 
     def call_later(self, delay, callback):