summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/state.py61
1 files changed, 41 insertions, 20 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7eb342674c..a82ba1d1d9 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -307,6 +307,9 @@ class StateStore(SQLBaseStore):
 
     def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
         results = {group: {} for group in groups}
+        if types is not None:
+            types = list(set(types))  # deduplicate types list
+
         if isinstance(self.database_engine, PostgresEngine):
             # Temporarily disable sequential scans in this transaction. This is
             # a temporary hack until we can add the right indices in
@@ -379,6 +382,44 @@ class StateStore(SQLBaseStore):
                 next_group = group
 
                 while next_group:
+                    # We did this before by getting the list of group ids, and
+                    # then passing that list to sqlite to get latest event for
+                    # each (type, state_key). However, that was terribly slow
+                    # without the right indicies (which we can't add until
+                    # after we finish deduping state, which requires this func)
+                    if types is not None:
+                        args = [next_group] + [i for typ in types for i in typ]
+                        txn.execute(
+                            "SELECT type, state_key, event_id FROM state_groups_state"
+                            " WHERE state_group = ? %s" % (where_clause,),
+                            args
+                        )
+                        rows = txn.fetchall()
+
+                        results[group].update({
+                            (typ, state_key): event_id
+                            for typ, state_key, event_id in rows
+                            if (typ, state_key) not in results[group]
+                        })
+
+                        # If the lengths match then we must have all the types,
+                        # so no need to go walk further down the tree.
+                        if len(results[group]) == len(types):
+                            break
+                    else:
+                        txn.execute(
+                            "SELECT type, state_key, event_id FROM state_groups_state"
+                            " WHERE state_group = ?",
+                            (next_group,)
+                        )
+                        rows = txn.fetchall()
+
+                        results[group].update({
+                            (typ, state_key): event_id
+                            for typ, state_key, event_id in rows
+                            if (typ, state_key) not in results[group]
+                        })
+
                     next_group = self._simple_select_one_onecol_txn(
                         txn,
                         table="state_group_edges",
@@ -389,26 +430,6 @@ class StateStore(SQLBaseStore):
                     if next_group:
                         group_tree.append(next_group)
 
-                sql = ("""
-                    SELECT type, state_key, event_id FROM state_groups_state
-                    INNER JOIN (
-                        SELECT type, state_key, max(state_group) as state_group
-                        FROM state_groups_state
-                        WHERE state_group IN (%s) %s
-                        GROUP BY type, state_key
-                    ) USING (type, state_key, state_group);
-                """) % (",".join("?" for _ in group_tree), where_clause,)
-
-                args = list(group_tree)
-                if types is not None:
-                    args.extend([i for typ in types for i in typ])
-
-                txn.execute(sql, args)
-                rows = self.cursor_to_dict(txn)
-                for row in rows:
-                    key = (row["type"], row["state_key"])
-                    results[group][key] = row["event_id"]
-
         return results
 
     @defer.inlineCallbacks