summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/state.py14
2 files changed, 13 insertions, 3 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 4d86fe7c72..e5441aafb2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -315,6 +315,8 @@ class CacheListDescriptor(object):
 
                 ret_d = ObservableDeferred(ret_d)
 
+                # We need to create deferreds for each arg in the list so that
+                # we can insert the new deferred into the cache.
                 for arg in missing:
                     observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r[arg], arg)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a438530071..ea5fa9de7b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -173,6 +173,9 @@ class StateStore(SQLBaseStore):
 
     def _get_state_groups_from_groups(self, groups_and_types):
         """Returns dictionary state_group -> state event ids
+
+        Args:
+            groups_and_types (list): list of 2-tuple (`group`, `types`)
         """
         def f(txn):
             results = {}
@@ -284,8 +287,11 @@ class StateStore(SQLBaseStore):
     def _get_state_for_group_from_cache(self, group, types=None):
         """Checks if group is in cache. See `_get_state_for_groups`
 
-        Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the
-        list of types that aren't in the cache for that group.
+        Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
+        `missing_types` is the list of types that aren't in the cache for that
+        group, or None if `types` is None. `got_all` is a bool indicating if
+        we successfully retrieved all requests state from the cache, if False
+        we need to query the DB for the missing state.
         """
         is_all, state_dict = self._state_group_cache.get(group)
 
@@ -323,11 +329,13 @@ class StateStore(SQLBaseStore):
                 return True
             return False
 
+        got_all = not (missing_types or types is None)
+
         return {
             k: v
             for k, v in state_dict.items()
             if include(k[0], k[1])
-        }, missing_types, not missing_types and types is not None
+        }, missing_types, got_all
 
     @defer.inlineCallbacks
     def _get_state_for_groups(self, groups, types=None):