| diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 372b540002..02cefdff26 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,9 +14,8 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import (
-    cached, cachedInlineCallbacks, cachedList
-)
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches import intern_string
 
 from twisted.internet import defer
 
@@ -83,7 +82,7 @@ class StateStore(SQLBaseStore):
             if event.is_state():
                 state_events[(event.type, event.state_key)] = event
 
-            state_group = self._state_groups_id_gen.get_next_txn(txn)
+            state_group = self._state_groups_id_gen.get_next()
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
@@ -155,8 +154,14 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
-    @cachedInlineCallbacks(num_args=3)
+    @defer.inlineCallbacks
     def get_current_state_for_key(self, room_id, event_type, state_key):
+        event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
+        events = yield self._get_events(event_ids, get_prev_content=False)
+        defer.returnValue(events)
+
+    @cached(num_args=3)
+    def _get_current_state_for_key(self, room_id, event_type, state_key):
         def f(txn):
             sql = (
                 "SELECT event_id FROM current_state_events"
@@ -167,12 +172,10 @@ class StateStore(SQLBaseStore):
             txn.execute(sql, args)
             results = txn.fetchall()
             return [r[0] for r in results]
-        event_ids = yield self.runInteraction("get_current_state_for_key", f)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
+        return self.runInteraction("get_current_state_for_key", f)
 
     def _get_state_groups_from_groups(self, groups, types):
-        """Returns dictionary state_group -> state event ids
+        """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
         """
         def f(txn, groups):
             if types is not None:
@@ -183,7 +186,8 @@ class StateStore(SQLBaseStore):
                 where_clause = ""
 
             sql = (
-                "SELECT state_group, event_id FROM state_groups_state WHERE"
+                "SELECT state_group, event_id, type, state_key"
+                " FROM state_groups_state WHERE"
                 " state_group IN (%s) %s" % (
                     ",".join("?" for _ in groups),
                     where_clause,
@@ -199,7 +203,8 @@ class StateStore(SQLBaseStore):
 
             results = {}
             for row in rows:
-                results.setdefault(row["state_group"], []).append(row["event_id"])
+                key = (row["type"], row["state_key"])
+                results.setdefault(row["state_group"], {})[key] = row["event_id"]
             return results
 
         chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
@@ -296,7 +301,7 @@ class StateStore(SQLBaseStore):
                 where a `state_key` of `None` matches all state_keys for the
                 `type`.
         """
-        is_all, state_dict = self._state_group_cache.get(group)
+        is_all, state_dict_ids = self._state_group_cache.get(group)
 
         type_to_key = {}
         missing_types = set()
@@ -308,7 +313,7 @@ class StateStore(SQLBaseStore):
                 if type_to_key.get(typ, object()) is not None:
                     type_to_key.setdefault(typ, set()).add(state_key)
 
-                if (typ, state_key) not in state_dict:
+                if (typ, state_key) not in state_dict_ids:
                     missing_types.add((typ, state_key))
 
         sentinel = object()
@@ -326,7 +331,7 @@ class StateStore(SQLBaseStore):
         got_all = not (missing_types or types is None)
 
         return {
-            k: v for k, v in state_dict.items()
+            k: v for k, v in state_dict_ids.items()
             if include(k[0], k[1])
         }, missing_types, got_all
 
@@ -340,8 +345,9 @@ class StateStore(SQLBaseStore):
         Args:
             group: The state group to lookup
         """
-        is_all, state_dict = self._state_group_cache.get(group)
-        return state_dict, is_all
+        is_all, state_dict_ids = self._state_group_cache.get(group)
+
+        return state_dict_ids, is_all
 
     @defer.inlineCallbacks
     def _get_state_for_groups(self, groups, types=None):
@@ -354,84 +360,72 @@ class StateStore(SQLBaseStore):
         missing_groups = []
         if types is not None:
             for group in set(groups):
-                state_dict, missing_types, got_all = self._get_some_state_from_cache(
+                state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
                     group, types
                 )
-                results[group] = state_dict
+                results[group] = state_dict_ids
 
                 if not got_all:
                     missing_groups.append(group)
         else:
             for group in set(groups):
-                state_dict, got_all = self._get_all_state_from_cache(
+                state_dict_ids, got_all = self._get_all_state_from_cache(
                     group
                 )
-                results[group] = state_dict
+
+                results[group] = state_dict_ids
 
                 if not got_all:
                     missing_groups.append(group)
 
-        if not missing_groups:
-            defer.returnValue({
-                group: {
-                    type_tuple: event
-                    for type_tuple, event in state.items()
-                    if event
-                }
-                for group, state in results.items()
-            })
+        if missing_groups:
+            # Okay, so we have some missing_types, lets fetch them.
+            cache_seq_num = self._state_group_cache.sequence
 
-        # Okay, so we have some missing_types, lets fetch them.
-        cache_seq_num = self._state_group_cache.sequence
+            group_to_state_dict = yield self._get_state_groups_from_groups(
+                missing_groups, types
+            )
 
-        group_state_dict = yield self._get_state_groups_from_groups(
-            missing_groups, types
-        )
+            # Now we want to update the cache with all the things we fetched
+            # from the database.
+            for group, group_state_dict in group_to_state_dict.items():
+                if types:
+                    # We delibrately put key -> None mappings into the cache to
+                    # cache absence of the key, on the assumption that if we've
+                    # explicitly asked for some types then we will probably ask
+                    # for them again.
+                    state_dict = {
+                        (intern_string(etype), intern_string(state_key)): None
+                        for (etype, state_key) in types
+                    }
+                    state_dict.update(results[group])
+                    results[group] = state_dict
+                else:
+                    state_dict = results[group]
+
+                state_dict.update(group_state_dict)
+
+                self._state_group_cache.update(
+                    cache_seq_num,
+                    key=group,
+                    value=state_dict,
+                    full=(types is None),
+                )
 
         state_events = yield self._get_events(
-            [e_id for l in group_state_dict.values() for e_id in l],
+            [ev_id for sd in results.values() for ev_id in sd.values()],
             get_prev_content=False
         )
 
         state_events = {e.event_id: e for e in state_events}
 
-        # Now we want to update the cache with all the things we fetched
-        # from the database.
-        for group, state_ids in group_state_dict.items():
-            if types:
-                # We delibrately put key -> None mappings into the cache to
-                # cache absence of the key, on the assumption that if we've
-                # explicitly asked for some types then we will probably ask
-                # for them again.
-                state_dict = {key: None for key in types}
-                state_dict.update(results[group])
-                results[group] = state_dict
-            else:
-                state_dict = results[group]
-
-            for event_id in state_ids:
-                try:
-                    state_event = state_events[event_id]
-                    state_dict[(state_event.type, state_event.state_key)] = state_event
-                except KeyError:
-                    # Hmm. So we do don't have that state event? Interesting.
-                    logger.warn(
-                        "Can't find state event %r for state group %r",
-                        event_id, group,
-                    )
-
-            self._state_group_cache.update(
-                cache_seq_num,
-                key=group,
-                value=state_dict,
-                full=(types is None),
-            )
-
         # Remove all the entries with None values. The None values were just
         # used for bookkeeping in the cache.
         for group, state_dict in results.items():
             results[group] = {
-                key: event for key, event in state_dict.items() if event
+                key: state_events[event_id]
+                for key, event_id in state_dict.items()
+                if event_id and event_id in state_events
             }
 
         defer.returnValue(results)
 |