summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/state/bg_updates.py99
1 files changed, 79 insertions, 20 deletions
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index a7fcc564a9..4a4ad0f492 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -93,13 +93,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
 
         results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
 
-        where_clause, where_args = state_filter.make_sql_filter_clause()
-
-        # Unless the filter clause is empty, we're going to append it after an
-        # existing where clause
-        if where_clause:
-            where_clause = " AND (%s)" % (where_clause,)
-
         if isinstance(self.database_engine, PostgresEngine):
             # Temporarily disable sequential scans in this transaction. This is
             # a temporary hack until we can add the right indices in
@@ -110,31 +103,91 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             # against `state_groups_state` to fetch the latest state.
             # It assumes that previous state groups are always numerically
             # lesser.
-            # The PARTITION is used to get the event_id in the greatest state
-            # group for the given type, state_key.
             # This may return multiple rows per (type, state_key), but last_value
             # should be the same.
             sql = """
-                WITH RECURSIVE state(state_group) AS (
+                WITH RECURSIVE sgs(state_group) AS (
                     VALUES(?::bigint)
                     UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
+                    SELECT prev_state_group FROM state_group_edges e, sgs s
                     WHERE s.state_group = e.state_group
                 )
-                SELECT DISTINCT ON (type, state_key)
-                    type, state_key, event_id
-                FROM state_groups_state
-                WHERE state_group IN (
-                    SELECT state_group FROM state
-                ) %s
-                ORDER BY type, state_key, state_group DESC
+                %s
             """
 
+            overall_select_query_args: List[Union[int, str]] = []
+
+            # This is an optimization to create a select clause per-condition. This
+            # makes the query planner a lot smarter on what rows should pull out in the
+            # first place and we end up with something that takes 10x less time to get a
+            # result.
+            use_condition_optimization = (
+                not state_filter.include_others and not state_filter.is_full()
+            )
+            state_filter_condition_combos: List[Tuple[str, Optional[str]]] = []
+            # We don't need to caclculate this list if we're not using the condition
+            # optimization
+            if use_condition_optimization:
+                for etype, state_keys in state_filter.types.items():
+                    if state_keys is None:
+                        state_filter_condition_combos.append((etype, None))
+                    else:
+                        for state_key in state_keys:
+                            state_filter_condition_combos.append((etype, state_key))
+            # And here is the optimization itself. We don't want to do the optimization
+            # if there are too many individual conditions. 10 is an arbitrary number
+            # with no testing behind it but we do know that we specifically made this
+            # optimization for when we grab the necessary state out for
+            # `filter_events_for_client` which just uses 2 conditions
+            # (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`).
+            if use_condition_optimization and len(state_filter_condition_combos) < 10:
+                select_clause_list: List[str] = []
+                for etype, skey in state_filter_condition_combos:
+                    if skey is None:
+                        where_clause = "(type = ?)"
+                        overall_select_query_args.extend([etype])
+                    else:
+                        where_clause = "(type = ? AND state_key = ?)"
+                        overall_select_query_args.extend([etype, skey])
+
+                    select_clause_list.append(
+                        f"""
+                        (
+                            SELECT DISTINCT ON (type, state_key)
+                                type, state_key, event_id
+                            FROM state_groups_state
+                            INNER JOIN sgs USING (state_group)
+                            WHERE {where_clause}
+                            ORDER BY type, state_key, state_group DESC
+                        )
+                        """
+                    )
+
+                overall_select_clause = " UNION ".join(select_clause_list)
+            else:
+                where_clause, where_args = state_filter.make_sql_filter_clause()
+                # Unless the filter clause is empty, we're going to append it after an
+                # existing where clause
+                if where_clause:
+                    where_clause = " AND (%s)" % (where_clause,)
+
+                overall_select_query_args.extend(where_args)
+
+                overall_select_clause = f"""
+                    SELECT DISTINCT ON (type, state_key)
+                        type, state_key, event_id
+                    FROM state_groups_state
+                    WHERE state_group IN (
+                        SELECT state_group FROM sgs
+                    ) {where_clause}
+                    ORDER BY type, state_key, state_group DESC
+                """
+
             for group in groups:
                 args: List[Union[int, str]] = [group]
-                args.extend(where_args)
+                args.extend(overall_select_query_args)
 
-                txn.execute(sql % (where_clause,), args)
+                txn.execute(sql % (overall_select_clause,), args)
                 for row in txn:
                     typ, state_key, event_id = row
                     key = (intern_string(typ), intern_string(state_key))
@@ -142,6 +195,12 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         else:
             max_entries_returned = state_filter.max_entries_returned()
 
+            where_clause, where_args = state_filter.make_sql_filter_clause()
+            # Unless the filter clause is empty, we're going to append it after an
+            # existing where clause
+            if where_clause:
+                where_clause = " AND (%s)" % (where_clause,)
+
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups: