summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/databases/state/bg_updates.py78
1 files changed, 65 insertions, 13 deletions
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c26860e0d6..73bcc5e613 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -90,8 +90,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         state_filter: Optional[StateFilter] = None,
     ) -> Mapping[int, StateMap[str]]:
         """
-        We can sort from smallest to largest state_group and re-use the work from the
-        small state_group for a larger one if we see that the edge chain links up.
+        TODO
         """
 
         state_filter = state_filter or StateFilter.all()
@@ -111,11 +110,22 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             # This may return multiple rows per (type, state_key), but last_value
             # should be the same.
             sql = """
-                WITH RECURSIVE sgs(state_group) AS (
-                    VALUES(?::bigint)
+                WITH RECURSIVE sgs(state_group, state_group_reached) AS (
+                    VALUES(?::bigint, NULL::bigint)
                     UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, sgs s
-                    WHERE s.state_group = e.state_group
+                    SELECT
+                        prev_state_group,
+                        CASE
+                            /* Specify state_groups we have already done the work for */
+                            WHEN @prev_state_group IN (%s) THEN prev_state_group
+                            ELSE NULL
+                        END AS state_group_reached
+                    FROM
+                        state_group_edges e, sgs s
+                    WHERE
+                        s.state_group = e.state_group
+                        /* Stop when we connect up to another state_group that we already did the work for */
+                        AND s.state_group_reached IS NULL
                 )
                 %s
             """
@@ -159,7 +169,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                         f"""
                         (
                             SELECT DISTINCT ON (type, state_key)
-                                type, state_key, event_id
+                                type, state_key, event_id, state_group
                             FROM state_groups_state
                             INNER JOIN sgs USING (state_group)
                             WHERE {where_clause}
@@ -180,7 +190,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
 
                 overall_select_clause = f"""
                     SELECT DISTINCT ON (type, state_key)
-                        type, state_key, event_id
+                        type, state_key, event_id, state_group
                     FROM state_groups_state
                     WHERE state_group IN (
                         SELECT state_group FROM sgs
@@ -188,15 +198,57 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                     ORDER BY type, state_key, state_group DESC
                 """
 
-            for group in groups:
+            # We can sort from smallest to largest state_group and re-use the work from
+            # the small state_group for a larger one if we see that the edge chain links
+            # up.
+            sorted_groups = sorted(groups)
+            state_groups_we_have_already_fetched: Set[int] = set()
+            for group in sorted_groups:
                 args: List[Union[int, str]] = [group]
                 args.extend(overall_select_query_args)
 
-                txn.execute(sql % (overall_select_clause,), args)
+                state_groups_we_have_already_fetched_string = [
+                    f"{state_group}::bigint"
+                    for state_group in state_groups_we_have_already_fetched
+                ].join(", ")
+
+                txn.execute(
+                    sql
+                    % (
+                        state_groups_we_have_already_fetched_string,
+                        overall_select_clause,
+                    ),
+                    args,
+                )
+
+                min_state_group: Optional[int] = None
+                partial_state_map_for_state_group: MutableStateMap[str] = {}
                 for row in txn:
-                    typ, state_key, event_id = row
+                    typ, state_key, event_id, state_group = row
                     key = (intern_string(typ), intern_string(state_key))
-                    results[group][key] = event_id
+                    partial_state_map_for_state_group[key] = event_id
+
+                    if state_group < min_state_group or min_state_group is None:
+                        min_state_group = state_group
+
+                # If we see a state group edge link to a previous state_group that we
+                # already fetched from the database, link up the base state to the
+                # partial state we retrieved from the database to build on top of.
+                if results[min_state_group] is not None:
+                    base_state_map = results[min_state_group].copy()
+
+                    results[group] = base_state_map.update(
+                        partial_state_map_for_state_group
+                    )
+                else:
+                    # It's also completely normal for us not to have a previous
+                    # state_group to build on top of if this is the first group being
+                    # processes or we are processing a bunch of groups from different
+                    # rooms which of course will never link together.
+                    results[group] = partial_state_map_for_state_group
+
+                state_groups_we_have_already_fetched.add(group)
+
         else:
             max_entries_returned = state_filter.max_entries_returned()