summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorŠimon Brandner <simon.bra.ag@gmail.com>2022-05-10 09:57:36 +0200
committerGitHub <noreply@github.com>2022-05-10 08:57:36 +0100
commitade30088212091284d066fc31a755a8a21050677 (patch)
tree5398226fb0090584da6856db2a61b3fe1cad8056 /synapse/storage
parentUpdate `replication.md` with info on TCP module structure (#12621) (diff)
downloadsynapse-ade30088212091284d066fc31a755a8a21050677.tar.xz
Implement MSC3786: Add a default push rule to ignore m.room.server_acl events (#12601)
Fixes vector-im/element-web#20788
Implements matrix-org/matrix-spec-proposals#3786
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/push_rule.py48
1 files changed, 37 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index eb85bbd392..4ed913e248 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -17,6 +17,7 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 from synapse.api.errors import StoreError
+from synapse.config.homeserver import ExperimentalConfig
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -42,7 +43,21 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _load_rules(rawrules, enabled_map):
+def _is_experimental_rule_enabled(
+    rule_id: str, experimental_config: ExperimentalConfig
+) -> bool:
+    """Used by `_load_rules` to filter out experimental rules when they
+    have not been enabled.
+    """
+    if (
+        rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
+        and not experimental_config.msc3786_enabled
+    ):
+        return False
+    return True
+
+
+def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
     ruleslist = []
     for rawrule in rawrules:
         rule = dict(rawrule)
@@ -51,17 +66,26 @@ def _load_rules(rawrules, enabled_map):
         rule["default"] = False
         ruleslist.append(rule)
 
-    # We're going to be mutating this a lot, so do a deep copy
-    rules = list(list_with_base_rules(ruleslist))
+    # We're going to be mutating this a lot, so copy it. We also filter out
+    # any experimental default push rules that aren't enabled.
+    rules = [
+        rule
+        for rule in list_with_base_rules(ruleslist)
+        if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
+    ]
 
     for i, rule in enumerate(rules):
         rule_id = rule["rule_id"]
-        if rule_id in enabled_map:
-            if rule.get("enabled", True) != bool(enabled_map[rule_id]):
-                # Rules are cached across users.
-                rule = dict(rule)
-                rule["enabled"] = bool(enabled_map[rule_id])
-                rules[i] = rule
+
+        if rule_id not in enabled_map:
+            continue
+        if rule.get("enabled", True) == bool(enabled_map[rule_id]):
+            continue
+
+        # Rules are cached across users.
+        rule = dict(rule)
+        rule["enabled"] = bool(enabled_map[rule_id])
+        rules[i] = rule
 
     return rules
 
@@ -141,7 +165,7 @@ class PushRulesWorkerStore(
 
         enabled_map = await self.get_push_rules_enabled_for_user(user_id)
 
-        return _load_rules(rows, enabled_map)
+        return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
     @cached(max_entries=5000)
     async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
@@ -200,7 +224,9 @@ class PushRulesWorkerStore(
         enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
         for user_id, rules in results.items():
-            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+            results[user_id] = _load_rules(
+                rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
+            )
 
         return results