summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py52
1 files changed, 29 insertions, 23 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7eb342674c..49abf0ac74 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -307,6 +307,9 @@ class StateStore(SQLBaseStore):
 
     def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
         results = {group: {} for group in groups}
+        if types is not None:
+            types = list(set(types))  # deduplicate types list
+
         if isinstance(self.database_engine, PostgresEngine):
             # Temporarily disable sequential scans in this transaction. This is
             # a temporary hack until we can add the right indices in
@@ -375,10 +378,35 @@ class StateStore(SQLBaseStore):
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
-                group_tree = [group]
                 next_group = group
 
                 while next_group:
+                    # We did this before by getting the list of group ids, and
+                    # then passing that list to sqlite to get latest event for
+                    # each (type, state_key). However, that was terribly slow
+                    # without the right indicies (which we can't add until
+                    # after we finish deduping state, which requires this func)
+                    args = [next_group]
+                    if types:
+                        args.extend(i for typ in types for i in typ)
+
+                    txn.execute(
+                        "SELECT type, state_key, event_id FROM state_groups_state"
+                        " WHERE state_group = ? %s" % (where_clause,),
+                        args
+                    )
+                    rows = txn.fetchall()
+                    results[group].update({
+                        (typ, state_key): event_id
+                        for typ, state_key, event_id in rows
+                        if (typ, state_key) not in results[group]
+                    })
+
+                    # If the lengths match then we must have all the types,
+                    # so no need to go walk further down the tree.
+                    if types is not None and len(results[group]) == len(types):
+                        break
+
                     next_group = self._simple_select_one_onecol_txn(
                         txn,
                         table="state_group_edges",
@@ -386,28 +414,6 @@ class StateStore(SQLBaseStore):
                         retcol="prev_state_group",
                         allow_none=True,
                     )
-                    if next_group:
-                        group_tree.append(next_group)
-
-                sql = ("""
-                    SELECT type, state_key, event_id FROM state_groups_state
-                    INNER JOIN (
-                        SELECT type, state_key, max(state_group) as state_group
-                        FROM state_groups_state
-                        WHERE state_group IN (%s) %s
-                        GROUP BY type, state_key
-                    ) USING (type, state_key, state_group);
-                """) % (",".join("?" for _ in group_tree), where_clause,)
-
-                args = list(group_tree)
-                if types is not None:
-                    args.extend([i for typ in types for i in typ])
-
-                txn.execute(sql, args)
-                rows = self.cursor_to_dict(txn)
-                for row in rows:
-                    key = (row["type"], row["state_key"])
-                    results[group][key] = row["event_id"]
 
         return results