summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py454
1 files changed, 374 insertions, 80 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85b8ec2b8f..4b971efdba 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,21 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 import logging
+from collections import namedtuple
 
 from six import iteritems, itervalues
 from six.moves import range
 
 from twisted.internet import defer
 
+from synapse.api.constants import EventTypes
+from synapse.api.errors import NotFoundError
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.engines import PostgresEngine
-from synapse.util.caches import intern_string, get_cache_factor_for
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.util.caches import get_cache_factor_for, intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.stringutils import to_ascii
-from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
         return len(self.delta_ids) if self.delta_ids else 0
 
 
-class StateGroupWorkerStore(SQLBaseStore):
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers.
     """
 
@@ -56,9 +60,68 @@ class StateGroupWorkerStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
         super(StateGroupWorkerStore, self).__init__(db_conn, hs)
 
+        # Originally the state store used a single DictionaryCache to cache the
+        # event IDs for the state types in a given state group to avoid hammering
+        # on the state_group* tables.
+        #
+        # The point of using a DictionaryCache is that it can cache a subset
+        # of the state events for a given state group (i.e. a subset of the keys for a
+        # given dict which is an entry in the cache for a given state group ID).
+        #
+        # However, this poses problems when performing complicated queries
+        # on the store - for instance: "give me all the state for this group, but
+        # limit members to this subset of users", as DictionaryCache's API isn't
+        # rich enough to say "please cache any of these fields, apart from this subset".
+        # This is problematic when lazy loading members, which requires this behaviour,
+        # as without it the cache has no choice but to speculatively load all
+        # state events for the group, which negates the efficiency being sought.
+        #
+        # Rather than overcomplicating DictionaryCache's API, we instead split the
+        # state_group_cache into two halves - one for tracking non-member events,
+        # and the other for tracking member_events.  This means that lazy loading
+        # queries can be made in a cache-friendly manner by querying both caches
+        # separately and then merging the result.  So for the example above, you
+        # would query the members cache for a specific subset of state keys
+        # (which DictionaryCache will handle efficiently and fine) and the non-members
+        # cache for all state (which DictionaryCache will similarly handle fine)
+        # and then just merge the results together.
+        #
+        # We size the non-members cache to be smaller than the members cache as the
+        # vast majority of state in Matrix (today) is member events.
+
         self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
+            "*stateGroupCache*",
+            # TODO: this hasn't been tuned yet
+            50000 * get_cache_factor_for("stateGroupCache")
         )
+        self._state_group_members_cache = DictionaryCache(
+            "*stateGroupMembersCache*",
+            500000 * get_cache_factor_for("stateGroupMembersCache")
+        )
+
+    @defer.inlineCallbacks
+    def get_room_version(self, room_id):
+        """Get the room_version of a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[str]
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
+        # for now we do this by looking at the create event. We may want to cache this
+        # more intelligently in future.
+        state_ids = yield self.get_current_state_ids(room_id)
+        create_id = state_ids.get((EventTypes.Create, ""))
+
+        if not create_id:
+            raise NotFoundError("Unknown room")
+
+        create_event = yield self.get_event(create_id)
+        defer.returnValue(create_event.content.get("room_version", "1"))
 
     @cached(max_entries=100000, iterable=True)
     def get_current_state_ids(self, room_id):
@@ -88,6 +151,69 @@ class StateGroupWorkerStore(SQLBaseStore):
             _get_current_state_ids_txn,
         )
 
+    # FIXME: how should this be cached?
+    def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+        """Get the current state event of a given type for a room based on the
+        current_state_events table.  This may not be as up-to-date as the result
+        of doing a fresh state resolution as per state_handler.get_current_state
+        Args:
+            room_id (str)
+            types (list[(Str, (Str|None))]): List of (type, state_key) tuples
+                which are used to filter the state fetched. `state_key` may be
+                None, which matches any `state_key`
+            filtered_types (list[Str]|None): List of types to apply the above filter to.
+        Returns:
+            deferred: dict of (type, state_key) -> event
+        """
+
+        include_other_types = False if filtered_types is None else True
+
+        def _get_filtered_current_state_ids_txn(txn):
+            results = {}
+            sql = """SELECT type, state_key, event_id FROM current_state_events
+                     WHERE room_id = ? %s"""
+            # Turns out that postgres doesn't like doing a list of OR's and
+            # is about 1000x slower, so we just issue a query for each specific
+            # type seperately.
+            if types:
+                clause_to_args = [
+                    (
+                        "AND type = ? AND state_key = ?",
+                        (etype, state_key)
+                    ) if state_key is not None else (
+                        "AND type = ?",
+                        (etype,)
+                    )
+                    for etype, state_key in types
+                ]
+
+                if include_other_types:
+                    unique_types = set(filtered_types)
+                    clause_to_args.append(
+                        (
+                            "AND type <> ? " * len(unique_types),
+                            list(unique_types)
+                        )
+                    )
+            else:
+                # If types is None we fetch all the state, and so just use an
+                # empty where clause with no extra args.
+                clause_to_args = [("", [])]
+            for where_clause, where_args in clause_to_args:
+                args = [room_id]
+                args.extend(where_args)
+                txn.execute(sql % (where_clause,), args)
+                for row in txn:
+                    typ, state_key, event_id = row
+                    key = (intern_string(typ), intern_string(state_key))
+                    results[key] = event_id
+            return results
+
+        return self.runInteraction(
+            "get_filtered_current_state_ids",
+            _get_filtered_current_state_ids_txn,
+        )
+
     @cached(max_entries=10000, iterable=True)
     def get_state_group_delta(self, state_group):
         """Given a state group try to return a previous group and a delta between
@@ -184,8 +310,21 @@ class StateGroupWorkerStore(SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, types):
-        """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
+    def _get_state_groups_from_groups(self, groups, types, members=None):
+        """Returns the state groups for a given set of groups, filtering on
+        types of state events.
+
+        Args:
+            groups(list[int]): list of state group IDs to query
+            types (Iterable[str, str|None]|None): list of 2-tuples of the form
+                (`type`, `state_key`), where a `state_key` of `None` matches all
+                state_keys for the `type`. If None, all types are returned.
+            members (bool|None): If not None, then, in addition to any filtering
+                implied by types, the results are also filtered to only include
+                member events (if True), or to exclude member events (if False)
+
+        Returns:
+            dictionary state_group -> (dict of (type, state_key) -> event id)
         """
         results = {}
 
@@ -193,14 +332,17 @@ class StateGroupWorkerStore(SQLBaseStore):
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, types,
+                self._get_state_groups_from_groups_txn, chunk, types, members,
             )
             results.update(res)
 
         defer.returnValue(results)
 
-    def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
+    def _get_state_groups_from_groups_txn(
+        self, txn, groups, types=None, members=None,
+    ):
         results = {group: {} for group in groups}
+
         if types is not None:
             types = list(set(types))  # deduplicate types list
 
@@ -235,10 +377,15 @@ class StateGroupWorkerStore(SQLBaseStore):
                 %s
             """)
 
+            if members is True:
+                sql += " AND type = '%s'" % (EventTypes.Member,)
+            elif members is False:
+                sql += " AND type <> '%s'" % (EventTypes.Member,)
+
             # Turns out that postgres doesn't like doing a list of OR's and
             # is about 1000x slower, so we just issue a query for each specific
             # type seperately.
-            if types:
+            if types is not None:
                 clause_to_args = [
                     (
                         "AND type = ? AND state_key = ?",
@@ -277,10 +424,16 @@ class StateGroupWorkerStore(SQLBaseStore):
                     else:
                         where_clauses.append("(type = ? AND state_key = ?)")
                         where_args.extend([typ[0], typ[1]])
+
                 where_clause = "AND (%s)" % (" OR ".join(where_clauses))
             else:
                 where_clause = ""
 
+            if members is True:
+                where_clause += " AND type = '%s'" % EventTypes.Member
+            elif members is False:
+                where_clause += " AND type <> '%s'" % EventTypes.Member
+
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
@@ -331,27 +484,30 @@ class StateGroupWorkerStore(SQLBaseStore):
         return results
 
     @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, types):
+    def get_state_for_events(self, event_ids, types, filtered_types=None):
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event. The state dicts will only have the type/state_keys
         that are in the `types` list.
 
         Args:
-            event_ids (list)
-            types (list): List of (type, state_key) tuples which are used to
-                filter the state fetched. `state_key` may be None, which matches
-                any `state_key`
+            event_ids (list[string])
+            types (list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
-            deferred: A list of dicts corresponding to the event_ids given.
-            The dicts are mappings from (type, state_key) -> state_events
+            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
         """
         event_to_groups = yield self._get_state_group_for_events(
             event_ids,
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types)
+        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
 
         state_event_map = yield self.get_events(
             [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@@ -370,25 +526,30 @@ class StateGroupWorkerStore(SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, types=None):
+    def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
         """
-        Get the state dicts corresponding to a list of events
+        Get the state dicts corresponding to a list of events, containing the event_ids
+        of the state events (as opposed to the events themselves)
 
         Args:
             event_ids(list(str)): events whose state should be returned
-            types(list[(str, str)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. May be None, which
-                matches any key
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
-            A deferred dict from event_id -> (type, state_key) -> state_event
+            A deferred dict from event_id -> (type, state_key) -> event_id
         """
         event_to_groups = yield self._get_state_group_for_events(
             event_ids,
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types)
+        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
 
         event_to_state = {
             event_id: group_to_state[group]
@@ -398,37 +559,45 @@ class StateGroupWorkerStore(SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, types=None):
+    def get_state_for_event(self, event_id, types=None, filtered_types=None):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str, str)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. May be None, which
-                matches any key
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_for_events([event_id], types)
+        state_map = yield self.get_state_for_events([event_id], types, filtered_types)
         defer.returnValue(state_map[event_id])
 
     @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, types=None):
+    def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str, str)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. May be None, which
-                matches any key
+            types(list[(str, str|None)]|None): List of (type, state_key) tuples
+                which are used to filter the state fetched. If `state_key` is None,
+                all events are returned of the given type.
+                May be None, which matches any key.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_ids_for_events([event_id], types)
+        state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
         defer.returnValue(state_map[event_id])
 
     @cached(max_entries=50000)
@@ -459,58 +628,76 @@ class StateGroupWorkerStore(SQLBaseStore):
 
         defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
 
-    def _get_some_state_from_cache(self, group, types):
+    def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
         """Checks if group is in cache. See `_get_state_for_groups`
 
-        Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
-        `missing_types` is the list of types that aren't in the cache for that
-        group. `got_all` is a bool indicating if we successfully retrieved all
+        Args:
+            cache(DictionaryCache): the state group cache to use
+            group(int): The state group to lookup
+            types(list[str, str|None]): List of 2-tuples of the form
+                (`type`, `state_key`), where a `state_key` of `None` matches all
+                state_keys for the `type`.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
+
+        Returns 2-tuple (`state_dict`, `got_all`).
+        `got_all` is a bool indicating if we successfully retrieved all
         requests state from the cache, if False we need to query the DB for the
         missing state.
-
-        Args:
-            group: The state group to lookup
-            types (list): List of 2-tuples of the form (`type`, `state_key`),
-                where a `state_key` of `None` matches all state_keys for the
-                `type`.
         """
-        is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
+        is_all, known_absent, state_dict_ids = cache.get(group)
 
         type_to_key = {}
-        missing_types = set()
+
+        # tracks whether any of our requested types are missing from the cache
+        missing_types = False
 
         for typ, state_key in types:
             key = (typ, state_key)
-            if state_key is None:
+
+            if (
+                state_key is None or
+                (filtered_types is not None and typ not in filtered_types)
+            ):
                 type_to_key[typ] = None
-                missing_types.add(key)
+                # we mark the type as missing from the cache because
+                # when the cache was populated it might have been done with a
+                # restricted set of state_keys, so the wildcard will not work
+                # and the cache may be incomplete.
+                missing_types = True
             else:
                 if type_to_key.get(typ, object()) is not None:
                     type_to_key.setdefault(typ, set()).add(state_key)
 
                 if key not in state_dict_ids and key not in known_absent:
-                    missing_types.add(key)
+                    missing_types = True
 
         sentinel = object()
 
         def include(typ, state_key):
             valid_state_keys = type_to_key.get(typ, sentinel)
             if valid_state_keys is sentinel:
-                return False
+                return filtered_types is not None and typ not in filtered_types
             if valid_state_keys is None:
                 return True
             if state_key in valid_state_keys:
                 return True
             return False
 
-        got_all = is_all or not missing_types
+        got_all = is_all
+        if not got_all:
+            # the cache is incomplete. We may still have got all the results we need, if
+            # we don't have any wildcards in the match list.
+            if not missing_types and filtered_types is None:
+                got_all = True
 
         return {
             k: v for k, v in iteritems(state_dict_ids)
             if include(k[0], k[1])
-        }, missing_types, got_all
+        }, got_all
 
-    def _get_all_state_from_cache(self, group):
+    def _get_all_state_from_cache(self, cache, group):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@@ -518,18 +705,91 @@ class StateGroupWorkerStore(SQLBaseStore):
         cache, if False we need to query the DB for the missing state.
 
         Args:
+            cache(DictionaryCache): the state group cache to use
             group: The state group to lookup
         """
-        is_all, _, state_dict_ids = self._state_group_cache.get(group)
+        is_all, _, state_dict_ids = cache.get(group)
 
         return state_dict_ids, is_all
 
     @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, types=None):
-        """Given list of groups returns dict of group -> list of state events
-        with matching types. `types` is a list of `(type, state_key)`, where
-        a `state_key` of None matches all state_keys. If `types` is None then
-        all events are returned.
+    def _get_state_for_groups(self, groups, types=None, filtered_types=None):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            types (None|iterable[(str, None|str)]):
+                indicates the state type/keys required. If None, the whole
+                state is fetched and returned.
+
+                Otherwise, each entry should be a `(type, state_key)` tuple to
+                include in the response. A `state_key` of None is a wildcard
+                meaning that we require all state with that type.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
+
+        Returns:
+            Deferred[dict[int, dict[(type, state_key), EventBase]]]
+                a dictionary mapping from state group to state dictionary.
+        """
+        if types is not None:
+            non_member_types = [t for t in types if t[0] != EventTypes.Member]
+
+            if filtered_types is not None and EventTypes.Member not in filtered_types:
+                # we want all of the membership events
+                member_types = None
+            else:
+                member_types = [t for t in types if t[0] == EventTypes.Member]
+
+        else:
+            non_member_types = None
+            member_types = None
+
+        non_member_state = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_cache, non_member_types, filtered_types,
+        )
+        # XXX: we could skip this entirely if member_types is []
+        member_state = yield self._get_state_for_groups_using_cache(
+            # we set filtered_types=None as member_state only ever contain members.
+            groups, self._state_group_members_cache, member_types, None,
+        )
+
+        state = non_member_state
+        for group in groups:
+            state[group].update(member_state[group])
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups_using_cache(
+        self, groups, cache, types=None, filtered_types=None
+    ):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key, querying from a specific cache.
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            cache (DictionaryCache): the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the specific
+                members state cache.
+            types (None|iterable[(str, None|str)]):
+                indicates the state type/keys required. If None, the whole
+                state is fetched and returned.
+
+                Otherwise, each entry should be a `(type, state_key)` tuple to
+                include in the response. A `state_key` of None is a wildcard
+                meaning that we require all state with that type.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
+
+        Returns:
+            Deferred[dict[int, dict[(type, state_key), EventBase]]]
+                a dictionary mapping from state group to state dictionary.
         """
         if types:
             types = frozenset(types)
@@ -537,8 +797,8 @@ class StateGroupWorkerStore(SQLBaseStore):
         missing_groups = []
         if types is not None:
             for group in set(groups):
-                state_dict_ids, _, got_all = self._get_some_state_from_cache(
-                    group, types
+                state_dict_ids, got_all = self._get_some_state_from_cache(
+                    cache, group, types, filtered_types
                 )
                 results[group] = state_dict_ids
 
@@ -547,7 +807,7 @@ class StateGroupWorkerStore(SQLBaseStore):
         else:
             for group in set(groups):
                 state_dict_ids, got_all = self._get_all_state_from_cache(
-                    group
+                    cache, group
                 )
 
                 results[group] = state_dict_ids
@@ -556,29 +816,46 @@ class StateGroupWorkerStore(SQLBaseStore):
                     missing_groups.append(group)
 
         if missing_groups:
-            # Okay, so we have some missing_types, lets fetch them.
-            cache_seq_num = self._state_group_cache.sequence
+            # Okay, so we have some missing_types, let's fetch them.
+            cache_seq_num = cache.sequence
+
+            # the DictionaryCache knows if it has *all* the state, but
+            # does not know if it has all of the keys of a particular type,
+            # which makes wildcard lookups expensive unless we have a complete
+            # cache. Hence, if we are doing a wildcard lookup, populate the
+            # cache fully so that we can do an efficient lookup next time.
+
+            if filtered_types or (types and any(k is None for (t, k) in types)):
+                types_to_fetch = None
+            else:
+                types_to_fetch = types
 
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types
+                missing_groups, types_to_fetch, cache == self._state_group_members_cache,
             )
 
-            # Now we want to update the cache with all the things we fetched
-            # from the database.
             for group, group_state_dict in iteritems(group_to_state_dict):
                 state_dict = results[group]
 
-                state_dict.update(
-                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
-                    for k, v in iteritems(group_state_dict)
-                )
-
-                self._state_group_cache.update(
+                # update the result, filtering by `types`.
+                if types:
+                    for k, v in iteritems(group_state_dict):
+                        (typ, _) = k
+                        if (
+                            (k in types or (typ, None) in types) or
+                            (filtered_types and typ not in filtered_types)
+                        ):
+                            state_dict[k] = v
+                else:
+                    state_dict.update(group_state_dict)
+
+                # update the cache with all the things we fetched from the
+                # database.
+                cache.update(
                     cache_seq_num,
                     key=group,
-                    value=state_dict,
-                    full=(types is None),
-                    known_absent=types,
+                    value=group_state_dict,
+                    fetched_keys=types_to_fetch,
                 )
 
         defer.returnValue(results)
@@ -676,16 +953,33 @@ class StateGroupWorkerStore(SQLBaseStore):
                     ],
                 )
 
-            # Prefill the state group cache with this group.
+            # Prefill the state group caches with this group.
             # It's fine to use the sequence like this as the state group map
             # is immutable. (If the map wasn't immutable then this prefill could
             # race with another update)
+
+            current_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] == EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_members_cache.update,
+                self._state_group_members_cache.sequence,
+                key=state_group,
+                value=dict(current_member_state_ids),
+            )
+
+            current_non_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] != EventTypes.Member
+            }
             txn.call_after(
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
                 key=state_group,
-                value=dict(current_state_ids),
-                full=True,
+                value=dict(current_non_member_state_ids),
             )
 
             return state_group