diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 45478c7a5a..49abf0ac74 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -386,38 +386,26 @@ class StateStore(SQLBaseStore):
# each (type, state_key). However, that was terribly slow
# without the right indicies (which we can't add until
# after we finish deduping state, which requires this func)
- if types is not None:
- args = [next_group] + [i for typ in types for i in typ]
- txn.execute(
- "SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? %s" % (where_clause,),
- args
- )
- rows = txn.fetchall()
-
- results[group].update({
- (typ, state_key): event_id
- for typ, state_key, event_id in rows
- if (typ, state_key) not in results[group]
- })
-
- # If the lengths match then we must have all the types,
- # so no need to go walk further down the tree.
- if len(results[group]) == len(types):
- break
- else:
- txn.execute(
- "SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ?",
- (next_group,)
- )
- rows = txn.fetchall()
+ args = [next_group]
+ if types:
+ args.extend(i for typ in types for i in typ)
- results[group].update({
- (typ, state_key): event_id
- for typ, state_key, event_id in rows
- if (typ, state_key) not in results[group]
- })
+ txn.execute(
+ "SELECT type, state_key, event_id FROM state_groups_state"
+ " WHERE state_group = ? %s" % (where_clause,),
+ args
+ )
+ rows = txn.fetchall()
+ results[group].update({
+ (typ, state_key): event_id
+ for typ, state_key, event_id in rows
+ if (typ, state_key) not in results[group]
+ })
+
+ # If the lengths match then we must have all the types,
+ # so no need to go walk further down the tree.
+ if types is not None and len(results[group]) == len(types):
+ break
next_group = self._simple_select_one_onecol_txn(
txn,
|