summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2b325e1c1f..783cebb351 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -240,6 +240,10 @@ class StateGroupWorkerStore(SQLBaseStore):
                     (
                         "AND type = ? AND state_key = ?",
                         (etype, state_key)
+                    ) if state_key is not None else
+                    (
+                        "AND type = ?",
+                        (etype)
                     )
                     for etype, state_key in types
                 ]
@@ -259,10 +263,19 @@ class StateGroupWorkerStore(SQLBaseStore):
                         key = (typ, state_key)
                         results[group][key] = event_id
         else:
+            where_args = []
             if types is not None:
-                where_clause = "AND (%s)" % (
-                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-                )
+                where_clause = "AND ("
+                for typ in types:
+                    if typ[1] is None:
+                        where_clause += "(type = ?)"
+                        where_args.extend(typ[0])
+                    else:
+                        where_clause += "(type = ? AND state_key = ?)"
+                        where_args.extend([typ[0], typ[1]])
+                    if typ != types[-1]:
+                        where_clause += " OR "
+                where_clause += ")"
             else:
                 where_clause = ""
 
@@ -279,7 +292,7 @@ class StateGroupWorkerStore(SQLBaseStore):
                     # after we finish deduping state, which requires this func)
                     args = [next_group]
                     if types:
-                        args.extend(i for typ in types for i in typ)
+                        args.extend(where_args)
 
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"