summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/push_rule.py50
1 files changed, 31 insertions, 19 deletions
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 22025eca56..37135d431d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,8 +28,11 @@ from typing import (
     cast,
 )
 
+from twisted.internet import defer
+
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_encoder, unwrapFirstError
+from synapse.util.async_helpers import gather_results
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -249,23 +253,33 @@ class PushRulesWorkerStore(
             user_id: [] for user_id in user_ids
         }
 
-        rows = cast(
-            List[Tuple[str, str, int, int, str, str]],
-            await self.db_pool.simple_select_many_batch(
-                table="push_rules",
-                column="user_name",
-                iterable=user_ids,
-                retcols=(
-                    "user_name",
-                    "rule_id",
-                    "priority_class",
-                    "priority",
-                    "conditions",
-                    "actions",
+        # gatherResults loses all type information.
+        rows, enabled_map_by_user = await make_deferred_yieldable(
+            gather_results(
+                (
+                    cast(
+                        "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]",
+                        run_in_background(
+                            self.db_pool.simple_select_many_batch,
+                            table="push_rules",
+                            column="user_name",
+                            iterable=user_ids,
+                            retcols=(
+                                "user_name",
+                                "rule_id",
+                                "priority_class",
+                                "priority",
+                                "conditions",
+                                "actions",
+                            ),
+                            desc="bulk_get_push_rules",
+                            batch_size=1000,
+                        ),
+                    ),
+                    run_in_background(self.bulk_get_push_rules_enabled, user_ids),
                 ),
-                desc="bulk_get_push_rules",
-                batch_size=1000,
-            ),
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
         )
 
         # Sort by highest priority_class, then highest priority.
@@ -276,8 +290,6 @@ class PushRulesWorkerStore(
                 (rule_id, priority_class, conditions, actions)
             )
 
-        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
-
         results: Dict[str, FilteredPushRules] = {}
 
         for user_id, rules in raw_rules.items():