| diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5f4fcd9832..3f6833fad2 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -428,6 +428,7 @@ class EventsStore(SQLBaseStore):
         # Now we need to work out the different state sets for
         # each state extremities
         state_sets = []
+        state_groups = set()
         missing_event_ids = []
         was_updated = False
         for event_id in new_latest_event_ids:
@@ -437,9 +438,17 @@ class EventsStore(SQLBaseStore):
                 if event_id == ev.event_id:
                     if ctx.current_state_ids is None:
                         raise Exception("Unknown current state")
-                    state_sets.append(ctx.current_state_ids)
-                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                        was_updated = True
+
+                    # If we've already seen the state group don't bother adding
+                    # it to the state sets again
+                    if ctx.state_group not in state_groups:
+                        state_sets.append(ctx.current_state_ids)
+                        if ctx.delta_ids or hasattr(ev, "state_key"):
+                            was_updated = True
+                        if ctx.state_group:
+                            # Add this as a seen state group (if it has a state
+                            # group)
+                            state_groups.add(ctx.state_group)
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
@@ -453,45 +462,51 @@ class EventsStore(SQLBaseStore):
                 missing_event_ids,
             )
 
-            groups = set(event_to_groups.itervalues())
-            group_to_state = yield self._get_state_for_groups(groups)
+            groups = set(event_to_groups.itervalues()) - state_groups
 
-            state_sets.extend(group_to_state.itervalues())
+            if groups:
+                group_to_state = yield self._get_state_for_groups(groups)
+                state_sets.extend(group_to_state.itervalues())
 
         if not new_latest_event_ids:
             current_state = {}
         elif was_updated:
-            # We work out the current state by passing the state sets to the
-            # state resolution algorithm. It may ask for some events, including
-            # the events we have yet to persist, so we need a slightly more
-            # complicated event lookup function than simply looking the events
-            # up in the db.
-            events_map = {ev.event_id: ev for ev, _ in events_context}
-
-            @defer.inlineCallbacks
-            def get_events(ev_ids):
-                # We get the events by first looking at the list of events we
-                # are trying to persist, and then fetching the rest from the DB.
-                db = []
-                to_return = {}
-                for ev_id in ev_ids:
-                    ev = events_map.get(ev_id, None)
-                    if ev:
-                        to_return[ev_id] = ev
-                    else:
-                        db.append(ev_id)
-
-                if db:
-                    evs = yield self.get_events(
-                        ev_ids, get_prev_content=False, check_redacted=False,
-                    )
-                    to_return.update(evs)
-                defer.returnValue(to_return)
+            if len(state_sets) == 1:
+                # If there is only one state set, then we know what the current
+                # state is.
+                current_state = state_sets[0]
+            else:
+                # We work out the current state by passing the state sets to the
+                # state resolution algorithm. It may ask for some events, including
+                # the events we have yet to persist, so we need a slightly more
+                # complicated event lookup function than simply looking the events
+                # up in the db.
+                events_map = {ev.event_id: ev for ev, _ in events_context}
+
+                @defer.inlineCallbacks
+                def get_events(ev_ids):
+                    # We get the events by first looking at the list of events we
+                    # are trying to persist, and then fetching the rest from the DB.
+                    db = []
+                    to_return = {}
+                    for ev_id in ev_ids:
+                        ev = events_map.get(ev_id, None)
+                        if ev:
+                            to_return[ev_id] = ev
+                        else:
+                            db.append(ev_id)
 
-            current_state = yield resolve_events(
-                state_sets,
-                state_map_factory=get_events,
-            )
+                    if db:
+                        evs = yield self.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
+                        )
+                        to_return.update(evs)
+                    defer.returnValue(to_return)
+
+                current_state = yield resolve_events(
+                    state_sets,
+                    state_map_factory=get_events,
+                )
         else:
             return
 |