diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index a7fcc564a9..4a4ad0f492 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -93,13 +93,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- # Unless the filter clause is empty, we're going to append it after an
- # existing where clause
- if where_clause:
- where_clause = " AND (%s)" % (where_clause,)
-
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
@@ -110,31 +103,91 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
- # The PARTITION is used to get the event_id in the greatest state
- # group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
- WITH RECURSIVE state(state_group) AS (
+ WITH RECURSIVE sgs(state_group) AS (
VALUES(?::bigint)
UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
+ SELECT prev_state_group FROM state_group_edges e, sgs s
WHERE s.state_group = e.state_group
)
- SELECT DISTINCT ON (type, state_key)
- type, state_key, event_id
- FROM state_groups_state
- WHERE state_group IN (
- SELECT state_group FROM state
- ) %s
- ORDER BY type, state_key, state_group DESC
+ %s
"""
+ overall_select_query_args: List[Union[int, str]] = []
+
+ # This is an optimization to create a select clause per-condition. This
+ # makes the query planner a lot smarter on what rows should pull out in the
+ # first place and we end up with something that takes 10x less time to get a
+ # result.
+ use_condition_optimization = (
+ not state_filter.include_others and not state_filter.is_full()
+ )
+ state_filter_condition_combos: List[Tuple[str, Optional[str]]] = []
+ # We don't need to caclculate this list if we're not using the condition
+ # optimization
+ if use_condition_optimization:
+ for etype, state_keys in state_filter.types.items():
+ if state_keys is None:
+ state_filter_condition_combos.append((etype, None))
+ else:
+ for state_key in state_keys:
+ state_filter_condition_combos.append((etype, state_key))
+ # And here is the optimization itself. We don't want to do the optimization
+ # if there are too many individual conditions. 10 is an arbitrary number
+ # with no testing behind it but we do know that we specifically made this
+ # optimization for when we grab the necessary state out for
+ # `filter_events_for_client` which just uses 2 conditions
+ # (`EventTypes.RoomHistoryVisibility` and `EventTypes.Member`).
+ if use_condition_optimization and len(state_filter_condition_combos) < 10:
+ select_clause_list: List[str] = []
+ for etype, skey in state_filter_condition_combos:
+ if skey is None:
+ where_clause = "(type = ?)"
+ overall_select_query_args.extend([etype])
+ else:
+ where_clause = "(type = ? AND state_key = ?)"
+ overall_select_query_args.extend([etype, skey])
+
+ select_clause_list.append(
+ f"""
+ (
+ SELECT DISTINCT ON (type, state_key)
+ type, state_key, event_id
+ FROM state_groups_state
+ INNER JOIN sgs USING (state_group)
+ WHERE {where_clause}
+ ORDER BY type, state_key, state_group DESC
+ )
+ """
+ )
+
+ overall_select_clause = " UNION ".join(select_clause_list)
+ else:
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
+ overall_select_query_args.extend(where_args)
+
+ overall_select_clause = f"""
+ SELECT DISTINCT ON (type, state_key)
+ type, state_key, event_id
+ FROM state_groups_state
+ WHERE state_group IN (
+ SELECT state_group FROM sgs
+ ) {where_clause}
+ ORDER BY type, state_key, state_group DESC
+ """
+
for group in groups:
args: List[Union[int, str]] = [group]
- args.extend(where_args)
+ args.extend(overall_select_query_args)
- txn.execute(sql % (where_clause,), args)
+ txn.execute(sql % (overall_select_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (intern_string(typ), intern_string(state_key))
@@ -142,6 +195,12 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
max_entries_returned = state_filter.max_entries_returned()
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
|