diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 405e6b6770..4ab16e18b2 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -245,8 +245,11 @@ class StateGroupWorkerStore(SQLBaseStore):
if types:
clause_to_args = [
(
- "AND type = ? AND state_key = ?" if state_key is not None else "AND type = ?",
- (etype, state_key) if state_key is not None else (etype)
+ "AND type = ? AND state_key = ?",
+ (etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
)
for etype, state_key in types
]
@@ -277,22 +280,25 @@ class StateGroupWorkerStore(SQLBaseStore):
results[group][key] = event_id
else:
where_args = []
+ where_clauses = []
+ wildcard_types = False
if types is not None:
- where_clause = "AND ("
for typ in types:
if typ[1] is None:
- where_clause += "(type = ?) OR "
+ where_clauses.append("(type = ?)")
where_args.extend(typ[0])
+ wildcard_types = True
else:
- where_clause += "(type = ? AND state_key = ?) OR "
+ where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]])
if include_other_types:
- where_clause += "(%s) OR " % (
- " AND ".join(["type <> ?"] * len(types)),
+ where_clauses.append(
+ "(" + " AND ".join(["type <> ?"] * len(types)) + ")"
)
where_args.extend(t for (t, _) in types)
- where_clause += "0)" # 0 to terminate the last OR
+
+ where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
@@ -322,9 +328,17 @@ class StateGroupWorkerStore(SQLBaseStore):
if (typ, state_key) not in results[group]
)
- # If the lengths match then we must have all the types,
- # so no need to go walk further down the tree.
- if types is not None and len(results[group]) == len(types):
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ types is not None and
+ not wildcard_types and
+ len(results[group]) == len(types)
+ ):
break
next_group = self._simple_select_one_onecol_txn(
|