summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/state.py85
1 files changed, 53 insertions, 32 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 79c3b82d9f..1293842361 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -287,42 +287,39 @@ class StateStore(SQLBaseStore):
             f,
         )
 
-    def _get_state_for_group_from_cache(self, group, types=None):
+    def _get_some_state_from_cache(self, group, types):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
         `missing_types` is the list of types that aren't in the cache for that
-        group, or None if `types` is None. `got_all` is a bool indicating if
-        we successfully retrieved all requests state from the cache, if False
-        we need to query the DB for the missing state.
+        group. `got_all` is a bool indicating if we successfully retrieved all
+        requests state from the cache, if False we need to query the DB for the
+        missing state.
+
+        Args:
+            group: The state group to lookup
+            types (list): List of 2-tuples of the form (`type`, `state_key`),
+                where a `state_key` of `None` matches all state_keys for the
+                `type`.
         """
         is_all, state_dict = self._state_group_cache.get(group)
 
         type_to_key = {}
         missing_types = set()
-        if types is not None:
-            for typ, state_key in types:
-                if state_key is None:
-                    type_to_key[typ] = None
-                    missing_types.add((typ, state_key))
-                else:
-                    if type_to_key.get(typ, object()) is not None:
-                        type_to_key.setdefault(typ, set()).add(state_key)
-
-                    if (typ, state_key) not in state_dict:
-                        missing_types.add((typ, state_key))
+        for typ, state_key in types:
+            if state_key is None:
+                type_to_key[typ] = None
+                missing_types.add((typ, state_key))
+            else:
+                if type_to_key.get(typ, object()) is not None:
+                    type_to_key.setdefault(typ, set()).add(state_key)
 
-        if is_all:
-            missing_types = set()
-            if types is None:
-                return state_dict, set(), True
+                if (typ, state_key) not in state_dict:
+                    missing_types.add((typ, state_key))
 
         sentinel = object()
 
         def include(typ, state_key):
-            if types is None:
-                return True
-
             valid_state_keys = type_to_key.get(typ, sentinel)
             if valid_state_keys is sentinel:
                 return False
@@ -340,6 +337,19 @@ class StateStore(SQLBaseStore):
             if include(k[0], k[1])
         }, missing_types, got_all
 
+    def _get_all_state_from_cache(self, group):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
+        indicating if we successfully retrieved all requests state from the
+        cache, if False we need to query the DB for the missing state.
+
+        Args:
+            group: The state group to lookup
+        """
+        is_all, state_dict = self._state_group_cache.get(group)
+        return state_dict, is_all
+
     @defer.inlineCallbacks
     def _get_state_for_groups(self, groups, types=None):
         """Given list of groups returns dict of group -> list of state events
@@ -349,18 +359,29 @@ class StateStore(SQLBaseStore):
         """
         results = {}
         missing_groups_and_types = []
-        for group in set(groups):
-            state_dict, missing_types, got_all = self._get_state_for_group_from_cache(
-                group, types
-            )
+        if types is not None:
+            for group in set(groups):
+                state_dict, missing_types, got_all = self._get_some_state_from_cache(
+                    group, types
+                )
+
+                results[group] = state_dict
+
+                if not got_all:
+                    missing_groups_and_types.append((
+                        group,
+                        missing_types
+                    ))
+        else:
+            for group in set(groups):
+                state_dict, got_all = self._get_all_state_from_cache(
+                    group
+                )
 
-            results[group] = state_dict
+                results[group] = state_dict
 
-            if not got_all:
-                missing_groups_and_types.append((
-                    group,
-                    missing_types if types else None
-                ))
+                if not got_all:
+                    missing_groups_and_types.append((group, None))
 
         if not missing_groups_and_types:
             defer.returnValue({