summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py56
-rw-r--r--synapse/storage/databases/main/push_rule.py50
-rw-r--r--synapse/util/async_helpers.py14
3 files changed, 86 insertions, 34 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 14784312dc..5934b1ef34 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -25,10 +25,13 @@ from typing import (
     Sequence,
     Tuple,
     Union,
+    cast,
 )
 
 from prometheus_client import Counter
 
+from twisted.internet.defer import Deferred
+
 from synapse.api.constants import (
     MAIN_TIMELINE,
     EventContentFields,
@@ -40,11 +43,15 @@ from synapse.api.room_versions import PushRuleRoomFlag
 from synapse.event_auth import auth_types_for_event, get_user_power_level
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
+from synapse.storage.roommember import ProfileInfo
 from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
 from synapse.types import JsonValue
 from synapse.types.state import StateFilter
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import gather_results
 from synapse.util.caches import register_cache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_event_for_clients_with_state
@@ -342,15 +349,41 @@ class BulkPushRuleEvaluator:
         rules_by_user = await self._get_rules_for_event(event)
         actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
 
-        room_member_count = await self.store.get_number_joined_users_in_room(
-            event.room_id
-        )
-
+        # Gather a bunch of info in parallel.
+        #
+        # This has a lot of ignored types and casting due to the use of @cached
+        # decorated functions passed into run_in_background.
+        #
+        # See https://github.com/matrix-org/synapse/issues/16606
         (
-            power_levels,
-            sender_power_level,
-        ) = await self._get_power_levels_and_sender_level(
-            event, context, event_id_to_event
+            room_member_count,
+            (power_levels, sender_power_level),
+            related_events,
+            profiles,
+        ) = await make_deferred_yieldable(
+            cast(
+                "Deferred[Tuple[int, Tuple[dict, Optional[int]], Dict[str, Dict[str, JsonValue]], Mapping[str, ProfileInfo]]]",
+                gather_results(
+                    (
+                        run_in_background(  # type: ignore[call-arg]
+                            self.store.get_number_joined_users_in_room, event.room_id  # type: ignore[arg-type]
+                        ),
+                        run_in_background(
+                            self._get_power_levels_and_sender_level,
+                            event,
+                            context,
+                            event_id_to_event,
+                        ),
+                        run_in_background(self._related_events, event),
+                        run_in_background(  # type: ignore[call-arg]
+                            self.store.get_subset_users_in_room_with_profiles,
+                            event.room_id,  # type: ignore[arg-type]
+                            rules_by_user.keys(),  # type: ignore[arg-type]
+                        ),
+                    ),
+                    consumeErrors=True,
+                ).addErrback(unwrapFirstError),
+            )
         )
 
         # Find the event's thread ID.
@@ -366,8 +399,6 @@ class BulkPushRuleEvaluator:
                 # the parent is part of a thread.
                 thread_id = await self.store.get_thread_id(relation.parent_id)
 
-        related_events = await self._related_events(event)
-
         # It's possible that old room versions have non-integer power levels (floats or
         # strings; even the occasional `null`). For old rooms, we interpret these as if
         # they were integers. Do this here for the `@room` power level threshold.
@@ -400,11 +431,6 @@ class BulkPushRuleEvaluator:
             self.hs.config.experimental.msc1767_enabled,  # MSC3931 flag
         )
 
-        users = rules_by_user.keys()
-        profiles = await self.store.get_subset_users_in_room_with_profiles(
-            event.room_id, users
-        )
-
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 22025eca56..37135d431d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,8 +28,11 @@ from typing import (
     cast,
 )
 
+from twisted.internet import defer
+
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_encoder, unwrapFirstError
+from synapse.util.async_helpers import gather_results
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -249,23 +253,33 @@ class PushRulesWorkerStore(
             user_id: [] for user_id in user_ids
         }
 
-        rows = cast(
-            List[Tuple[str, str, int, int, str, str]],
-            await self.db_pool.simple_select_many_batch(
-                table="push_rules",
-                column="user_name",
-                iterable=user_ids,
-                retcols=(
-                    "user_name",
-                    "rule_id",
-                    "priority_class",
-                    "priority",
-                    "conditions",
-                    "actions",
+        # gatherResults loses all type information.
+        rows, enabled_map_by_user = await make_deferred_yieldable(
+            gather_results(
+                (
+                    cast(
+                        "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]",
+                        run_in_background(
+                            self.db_pool.simple_select_many_batch,
+                            table="push_rules",
+                            column="user_name",
+                            iterable=user_ids,
+                            retcols=(
+                                "user_name",
+                                "rule_id",
+                                "priority_class",
+                                "priority",
+                                "conditions",
+                                "actions",
+                            ),
+                            desc="bulk_get_push_rules",
+                            batch_size=1000,
+                        ),
+                    ),
+                    run_in_background(self.bulk_get_push_rules_enabled, user_ids),
                 ),
-                desc="bulk_get_push_rules",
-                batch_size=1000,
-            ),
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
         )
 
         # Sort by highest priority_class, then highest priority.
@@ -276,8 +290,6 @@ class PushRulesWorkerStore(
                 (rule_id, priority_class, conditions, actions)
             )
 
-        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
-
         results: Dict[str, FilteredPushRules] = {}
 
         for user_id, rules in raw_rules.items():
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 0cbeb0c365..8a55e4e41d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -345,6 +345,7 @@ async def yieldable_gather_results_delaying_cancellation(
 T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
+T4 = TypeVar("T4")
 
 
 @overload
@@ -380,6 +381,19 @@ def gather_results(
     ...
 
 
+@overload
+def gather_results(
+    deferredList: Tuple[
+        "defer.Deferred[T1]",
+        "defer.Deferred[T2]",
+        "defer.Deferred[T3]",
+        "defer.Deferred[T4]",
+    ],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]":
+    ...
+
+
 def gather_results(  # type: ignore[misc]
     deferredList: Tuple["defer.Deferred[T1]", ...],
     consumeErrors: bool = False,