summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/state.py22
1 files changed, 16 insertions, 6 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 48a4023558..e924258d11 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -234,7 +234,8 @@ class StateStore(SQLBaseStore):
             ])
 
             sql = (
-                "SELECT sg.event_id FROM state_groups_state as sg"
+                "SELECT e.event_id, sg.state_group, sg.event_id"
+                " FROM state_groups_state as sg"
                 " INNER JOIN event_to_state_groups as e"
                 " ON e.state_group = sg.state_group"
                 " WHERE e.event_id = ? AND (%s)"
@@ -342,8 +343,9 @@ class StateStore(SQLBaseStore):
             defer.returnValue(state_dict)
 
         if is_all or (types is not None and not missing_types):
+            sentinel = object()
+
             def include(typ, state_key):
-                sentinel = object()
                 valid_state_keys = type_to_key.get(typ, sentinel)
                 if valid_state_keys is sentinel:
                     return False
@@ -356,20 +358,24 @@ class StateStore(SQLBaseStore):
             defer.returnValue({
                 k: v
                 for k, v in state_dict.items()
-                if include(k[0], k[1])
+                if v and include(k[0], k[1])
             })
 
         # Okay, so we have some missing_types, lets fetch them.
         cache_seq_num = self._state_group_cache.sequence
         _, state_ids = yield self._get_state_groups_from_group(
             group,
-            frozenset(types) if types else None
+            frozenset(missing_types) if types else None
         )
         state_events = yield self._get_events(state_ids, get_prev_content=False)
         state_dict = {
+            key: None
+            for key in missing_types
+        }
+        state_dict.update({
             (e.type, e.state_key): e
             for e in state_events
-        }
+        })
 
         # Update the cache
         self._state_group_cache.update(
@@ -379,7 +385,11 @@ class StateStore(SQLBaseStore):
             full=(types is None),
         )
 
-        defer.returnValue(state_dict)
+        defer.returnValue({
+            key: value
+            for key, value in state_dict.items()
+            if value
+        })
 
 
 def _make_group_id(clock):