summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/state.py85
1 files changed, 49 insertions, 36 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 64c5ae9928..19b16ed404 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -283,6 +283,9 @@ class StateStore(SQLBaseStore):
 
     def _get_state_for_group_from_cache(self, group, types=None):
         """Checks if group is in cache. See `_get_state_for_groups`
+
+        Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the
+        list of types that aren't in the cache for that group.
         """
         is_all, state_dict = self._state_group_cache.get(group)
 
@@ -300,29 +303,31 @@ class StateStore(SQLBaseStore):
                     if (typ, state_key) not in state_dict:
                         missing_types.add((typ, state_key))
 
-        if is_all and types is None:
-            return state_dict, missing_types
-
-        if is_all or (types is not None and not missing_types):
-            sentinel = object()
+        if is_all:
+            missing_types = set()
+            if types is None:
+                return state_dict, set(), True
 
-            def include(typ, state_key):
-                valid_state_keys = type_to_key.get(typ, sentinel)
-                if valid_state_keys is sentinel:
-                    return False
-                if valid_state_keys is None:
-                    return True
-                if state_key in valid_state_keys:
-                    return True
-                return False
+        sentinel = object()
 
-            return {
-                k: v
-                for k, v in state_dict.items()
-                if v and include(k[0], k[1])
-            }, missing_types
+        def include(typ, state_key):
+            if types is None:
+                return True
 
-        return {}, missing_types
+            valid_state_keys = type_to_key.get(typ, sentinel)
+            if valid_state_keys is sentinel:
+                return False
+            if valid_state_keys is None:
+                return True
+            if state_key in valid_state_keys:
+                return True
+            return False
+
+        return {
+            k: v
+            for k, v in state_dict.items()
+            if include(k[0], k[1])
+        }, missing_types, not missing_types and types is not None
 
     @defer.inlineCallbacks
     def _get_state_for_groups(self, groups, types=None):
@@ -333,25 +338,28 @@ class StateStore(SQLBaseStore):
         """
         results = {}
         missing_groups_and_types = []
-        for group in groups:
-            state_dict, missing_types = self._get_state_for_group_from_cache(
+        for group in set(groups):
+            state_dict, missing_types, got_all = self._get_state_for_group_from_cache(
                 group, types
             )
 
-            if types is not None and not missing_types:
-                results[group] = {
-                    key: value
-                    for key, value in state_dict.items()
-                    if value
-                }
-            else:
+            results[group] = state_dict
+
+            if not got_all:
                 missing_groups_and_types.append((
                     group,
                     missing_types if types else None
                 ))
 
         if not missing_groups_and_types:
-            defer.returnValue(results)
+            defer.returnValue({
+                k: {
+                    key: ev
+                    for key, ev in state.items()
+                    if ev
+                }
+                for k, state in results.items()
+            })
 
         # Okay, so we have some missing_types, lets fetch them.
         cache_seq_num = self._state_group_cache.sequence
@@ -371,10 +379,15 @@ class StateStore(SQLBaseStore):
         }
 
         for group, state_ids in group_state_dict.items():
-            state_dict = {
-                key: None
-                for key in missing_types
-            }
+            if types:
+                state_dict = {
+                    key: None
+                    for key in types
+                }
+                state_dict.update(results[group])
+            else:
+                state_dict = results[group]
+
             evs = [
                 state_events[e_id] for e_id in state_ids
                 if e_id in state_events  # This can happen if event is rejected.
@@ -392,11 +405,11 @@ class StateStore(SQLBaseStore):
                 full=(types is None),
             )
 
-            results[group] = {
+            results[group].update({
                 key: value
                 for key, value in state_dict.items()
                 if value
-            }
+            })
 
         defer.returnValue(results)