| diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85acf2ad1e..5673e4aa96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii
 from synapse.storage.engines import PostgresEngine
 
 from twisted.internet import defer
+from collections import namedtuple
 
 import logging
 
@@ -29,6 +30,16 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
+class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+    """Return type of get_state_group_delta that implements __len__, which lets
+    us use the itrable flag when caching
+    """
+    __slots__ = []
+
+    def __len__(self):
+        return len(self.delta_ids) if self.delta_ids else 0
+
+
 class StateStore(SQLBaseStore):
     """ Keeps track of the state at a given event.
 
@@ -98,6 +109,46 @@ class StateStore(SQLBaseStore):
             _get_current_state_ids_txn,
         )
 
+    @cached(max_entries=10000, iterable=True)
+    def get_state_group_delta(self, state_group):
+        """Given a state group try to return a previous group and a delta between
+        the old and the new.
+
+        Returns:
+            (prev_group, delta_ids), where both may be None.
+        """
+        def _get_state_group_delta_txn(txn):
+            prev_group = self._simple_select_one_onecol_txn(
+                txn,
+                table="state_group_edges",
+                keyvalues={
+                    "state_group": state_group,
+                },
+                retcol="prev_state_group",
+                allow_none=True,
+            )
+
+            if not prev_group:
+                return _GetStateGroupDelta(None, None)
+
+            delta_ids = self._simple_select_list_txn(
+                txn,
+                table="state_groups_state",
+                keyvalues={
+                    "state_group": state_group,
+                },
+                retcols=("type", "state_key", "event_id",)
+            )
+
+            return _GetStateGroupDelta(prev_group, {
+                (row["type"], row["state_key"]): row["event_id"]
+                for row in delta_ids
+            })
+        return self.runInteraction(
+            "get_state_group_delta",
+            _get_state_group_delta_txn,
+        )
+
     @defer.inlineCallbacks
     def get_state_groups_ids(self, room_id, event_ids):
         if not event_ids:
@@ -184,6 +235,19 @@ class StateStore(SQLBaseStore):
             # We persist as a delta if we can, while also ensuring the chain
             # of deltas isn't tooo long, as otherwise read performance degrades.
             if context.prev_group:
+                is_in_db = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_groups",
+                    keyvalues={"id": context.prev_group},
+                    retcol="id",
+                    allow_none=True,
+                )
+                if not is_in_db:
+                    raise Exception(
+                        "Trying to persist state with unpersisted prev_group: %r"
+                        % (context.prev_group,)
+                    )
+
                 potential_hops = self._count_state_group_hops_txn(
                     txn, context.prev_group
                 )
@@ -251,6 +315,12 @@ class StateStore(SQLBaseStore):
             ],
         )
 
+        for event_id, state_group_id in state_groups.iteritems():
+            txn.call_after(
+                self._get_state_group_for_event.prefill,
+                (event_id,), state_group_id
+            )
+
     def _count_state_group_hops_txn(self, txn, state_group):
         """Given a state group, count how many hops there are in the tree.
 
@@ -520,8 +590,8 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, max_entries=50000)
-    def _get_state_group_for_event(self, room_id, event_id):
+    @cached(max_entries=50000)
+    def _get_state_group_for_event(self, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={
@@ -563,20 +633,22 @@ class StateStore(SQLBaseStore):
                 where a `state_key` of `None` matches all state_keys for the
                 `type`.
         """
-        is_all, state_dict_ids = self._state_group_cache.get(group)
+        is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
 
         type_to_key = {}
         missing_types = set()
+
         for typ, state_key in types:
+            key = (typ, state_key)
             if state_key is None:
                 type_to_key[typ] = None
-                missing_types.add((typ, state_key))
+                missing_types.add(key)
             else:
                 if type_to_key.get(typ, object()) is not None:
                     type_to_key.setdefault(typ, set()).add(state_key)
 
-                if (typ, state_key) not in state_dict_ids:
-                    missing_types.add((typ, state_key))
+                if key not in state_dict_ids and key not in known_absent:
+                    missing_types.add(key)
 
         sentinel = object()
 
@@ -590,7 +662,7 @@ class StateStore(SQLBaseStore):
                 return True
             return False
 
-        got_all = not (missing_types or types is None)
+        got_all = is_all or not missing_types
 
         return {
             k: v for k, v in state_dict_ids.iteritems()
@@ -607,7 +679,7 @@ class StateStore(SQLBaseStore):
         Args:
             group: The state group to lookup
         """
-        is_all, state_dict_ids = self._state_group_cache.get(group)
+        is_all, _, state_dict_ids = self._state_group_cache.get(group)
 
         return state_dict_ids, is_all
 
@@ -624,7 +696,7 @@ class StateStore(SQLBaseStore):
         missing_groups = []
         if types is not None:
             for group in set(groups):
-                state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+                state_dict_ids, _, got_all = self._get_some_state_from_cache(
                     group, types
                 )
                 results[group] = state_dict_ids
@@ -653,19 +725,7 @@ class StateStore(SQLBaseStore):
             # Now we want to update the cache with all the things we fetched
             # from the database.
             for group, group_state_dict in group_to_state_dict.iteritems():
-                if types:
-                    # We delibrately put key -> None mappings into the cache to
-                    # cache absence of the key, on the assumption that if we've
-                    # explicitly asked for some types then we will probably ask
-                    # for them again.
-                    state_dict = {
-                        (intern_string(etype), intern_string(state_key)): None
-                        for (etype, state_key) in types
-                    }
-                    state_dict.update(results[group])
-                    results[group] = state_dict
-                else:
-                    state_dict = results[group]
+                state_dict = results[group]
 
                 state_dict.update(
                     ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
@@ -677,17 +737,9 @@ class StateStore(SQLBaseStore):
                     key=group,
                     value=state_dict,
                     full=(types is None),
+                    known_absent=types,
                 )
 
-        # Remove all the entries with None values. The None values were just
-        # used for bookkeeping in the cache.
-        for group, state_dict in results.iteritems():
-            results[group] = {
-                key: event_id
-                for key, event_id in state_dict.iteritems()
-                if event_id
-            }
-
         defer.returnValue(results)
 
     def get_next_state_group(self):
 |