summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/push_rule.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 264521635f..19a0211a03 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -38,7 +38,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-def _load_rules(rawrules, enabled_map):
+def _load_rules(rawrules, enabled_map, use_new_defaults=False):
     ruleslist = []
     for rawrule in rawrules:
         rule = dict(rawrule)
@@ -48,7 +48,7 @@ def _load_rules(rawrules, enabled_map):
         ruleslist.append(rule)
 
     # We're going to be mutating this a lot, so do a deep copy
-    rules = list(list_with_base_rules(ruleslist))
+    rules = list(list_with_base_rules(ruleslist, use_new_defaults))
 
     for i, rule in enumerate(rules):
         rule_id = rule["rule_id"]
@@ -104,6 +104,8 @@ class PushRulesWorkerStore(
             prefilled_cache=push_rules_prefill,
         )
 
+        self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+
     @abc.abstractmethod
     def get_max_push_rules_stream_id(self):
         """Get the position of the push rules stream.
@@ -133,7 +135,9 @@ class PushRulesWorkerStore(
 
         enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
 
-        rules = _load_rules(rows, enabled_map)
+        use_new_defaults = user_id in self._users_new_default_push_rules
+
+        rules = _load_rules(rows, enabled_map, use_new_defaults)
 
         return rules
 
@@ -193,7 +197,11 @@ class PushRulesWorkerStore(
         enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
 
         for user_id, rules in results.items():
-            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+            use_new_defaults = user_id in self._users_new_default_push_rules
+
+            results[user_id] = _load_rules(
+                rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
+            )
 
         return results