summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/state.py26
1 files changed, 16 insertions, 10 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b796d3c995..405e6b6770 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -276,15 +276,23 @@ class StateGroupWorkerStore(SQLBaseStore):
                         key = (typ, state_key)
                         results[group][key] = event_id
         else:
+            where_args = []
             if types is not None:
-                where_clause = "AND (%s" % (
-                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-                )
+                where_clause = "AND ("
+                for typ in types:
+                    if typ[1] is None:
+                        where_clause += "(type = ?) OR "
+                        where_args.extend(typ[0])
+                    else:
+                        where_clause += "(type = ? AND state_key = ?) OR "
+                        where_args.extend([typ[0], typ[1]])
+
                 if include_other_types:
-                    where_clause += " OR (%s)" % (
+                    where_clause += "(%s) OR " % (
                         " AND ".join(["type <> ?"] * len(types)),
                     )
-                where_clause += ")"
+                    where_args.extend(t for (t, _) in types)
+                where_clause += "0)"  # 0 to terminate the last OR
             else:
                 where_clause = ""
 
@@ -301,9 +309,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                     # after we finish deduping state, which requires this func)
                     args = [next_group]
                     if types:
-                        args.extend(i for typ in types for i in typ)
-                        if include_other_types:
-                            args.extend(typ for (typ, _) in types)
+                        args.extend(where_args)
 
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"
@@ -507,12 +513,12 @@ class StateGroupWorkerStore(SQLBaseStore):
         def include(typ, state_key):
             valid_state_keys = type_to_key.get(typ, sentinel)
             if valid_state_keys is sentinel:
-                return False
+                return include_other_types
             if valid_state_keys is None:
                 return True
             if state_key in valid_state_keys:
                 return True
-            return include_other_types
+            return False
 
         got_all = is_all or not missing_types