diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c26860e0d6..73bcc5e613 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -90,8 +90,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
state_filter: Optional[StateFilter] = None,
) -> Mapping[int, StateMap[str]]:
"""
- We can sort from smallest to largest state_group and re-use the work from the
- small state_group for a larger one if we see that the edge chain links up.
+        Fetch the full state for each of the given state groups, re-using
"""
state_filter = state_filter or StateFilter.all()
@@ -111,11 +110,22 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
- WITH RECURSIVE sgs(state_group) AS (
- VALUES(?::bigint)
+ WITH RECURSIVE sgs(state_group, state_group_reached) AS (
+ VALUES(?::bigint, NULL::bigint)
UNION ALL
- SELECT prev_state_group FROM state_group_edges e, sgs s
- WHERE s.state_group = e.state_group
+ SELECT
+ prev_state_group,
+ CASE
+ /* Specify state_groups we have already done the work for */
+                        WHEN prev_state_group IN (%s) THEN prev_state_group
+ ELSE NULL
+ END AS state_group_reached
+ FROM
+ state_group_edges e, sgs s
+ WHERE
+ s.state_group = e.state_group
+ /* Stop when we connect up to another state_group that we already did the work for */
+ AND s.state_group_reached IS NULL
)
%s
"""
@@ -159,7 +169,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
f"""
(
SELECT DISTINCT ON (type, state_key)
- type, state_key, event_id
+ type, state_key, event_id, state_group
FROM state_groups_state
INNER JOIN sgs USING (state_group)
WHERE {where_clause}
@@ -180,7 +190,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
overall_select_clause = f"""
SELECT DISTINCT ON (type, state_key)
- type, state_key, event_id
+ type, state_key, event_id, state_group
FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM sgs
@@ -188,15 +198,57 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
ORDER BY type, state_key, state_group DESC
"""
- for group in groups:
+ # We can sort from smallest to largest state_group and re-use the work from
+ # the small state_group for a larger one if we see that the edge chain links
+ # up.
+ sorted_groups = sorted(groups)
+ state_groups_we_have_already_fetched: Set[int] = set()
+ for group in sorted_groups:
args: List[Union[int, str]] = [group]
args.extend(overall_select_query_args)
- txn.execute(sql % (overall_select_clause,), args)
+            state_groups_we_have_already_fetched_string = ", ".join(
+                f"{state_group}::bigint"
+                for state_group in state_groups_we_have_already_fetched
+            ) or "NULL::bigint"
+
+ txn.execute(
+ sql
+ % (
+ state_groups_we_have_already_fetched_string,
+ overall_select_clause,
+ ),
+ args,
+ )
+
+ min_state_group: Optional[int] = None
+ partial_state_map_for_state_group: MutableStateMap[str] = {}
for row in txn:
- typ, state_key, event_id = row
+ typ, state_key, event_id, state_group = row
key = (intern_string(typ), intern_string(state_key))
- results[group][key] = event_id
+ partial_state_map_for_state_group[key] = event_id
+
+                if min_state_group is None or state_group < min_state_group:
+ min_state_group = state_group
+
+ # If we see a state group edge link to a previous state_group that we
+ # already fetched from the database, link up the base state to the
+ # partial state we retrieved from the database to build on top of.
+            if min_state_group is not None and min_state_group in results:
+                base_state_map = results[min_state_group].copy()
+
+                base_state_map.update(partial_state_map_for_state_group)
+                # NOTE: dict.update() returns None, so assign the mutated copy
+                results[group] = base_state_map
+ else:
+ # It's also completely normal for us not to have a previous
+ # state_group to build on top of if this is the first group being
+                # processed or we are processing a bunch of groups from different
+ # rooms which of course will never link together.
+ results[group] = partial_state_map_for_state_group
+
+ state_groups_we_have_already_fetched.add(group)
+
else:
max_entries_returned = state_filter.max_entries_returned()
|