| diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 45478c7a5a..49abf0ac74 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -386,38 +386,26 @@ class StateStore(SQLBaseStore):
                     # each (type, state_key). However, that was terribly slow
                     # without the right indicies (which we can't add until
                     # after we finish deduping state, which requires this func)
-                    if types is not None:
-                        args = [next_group] + [i for typ in types for i in typ]
-                        txn.execute(
-                            "SELECT type, state_key, event_id FROM state_groups_state"
-                            " WHERE state_group = ? %s" % (where_clause,),
-                            args
-                        )
-                        rows = txn.fetchall()
-
-                        results[group].update({
-                            (typ, state_key): event_id
-                            for typ, state_key, event_id in rows
-                            if (typ, state_key) not in results[group]
-                        })
-
-                        # If the lengths match then we must have all the types,
-                        # so no need to go walk further down the tree.
-                        if len(results[group]) == len(types):
-                            break
-                    else:
-                        txn.execute(
-                            "SELECT type, state_key, event_id FROM state_groups_state"
-                            " WHERE state_group = ?",
-                            (next_group,)
-                        )
-                        rows = txn.fetchall()
+                    args = [next_group]
+                    if types:
+                        args.extend(i for typ in types for i in typ)
 
-                        results[group].update({
-                            (typ, state_key): event_id
-                            for typ, state_key, event_id in rows
-                            if (typ, state_key) not in results[group]
-                        })
+                    txn.execute(
+                        "SELECT type, state_key, event_id FROM state_groups_state"
+                        " WHERE state_group = ? %s" % (where_clause,),
+                        args
+                    )
+                    rows = txn.fetchall()
+                    results[group].update({
+                        (typ, state_key): event_id
+                        for typ, state_key, event_id in rows
+                        if (typ, state_key) not in results[group]
+                    })
+
+                    # If the lengths match then we must have all the types,
+                    # so no need to go walk further down the tree.
+                    if types is not None and len(results[group]) == len(types):
+                        break
 
                     next_group = self._simple_select_one_onecol_txn(
                         txn,
 |