diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 768f95d16c..12ad44dbb3 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -12,15 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+)
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.replication.tcp.streams import PushRulesStream
+from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -39,6 +50,7 @@ from synapse.storage.util.id_generators import (
IdGenerator,
StreamIdGenerator,
)
+from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -50,64 +62,34 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _is_experimental_rule_enabled(
- rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
- """Used by `_load_rules` to filter out experimental rules when they
- have not been enabled.
- """
- if (
- rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
- and not experimental_config.msc3786_enabled
- ):
- return False
- if (
- rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
- and not experimental_config.msc3772_enabled
- ):
- return False
- return True
-
-
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = db_to_json(rawrule["conditions"])
- rule["actions"] = db_to_json(rawrule["actions"])
- rule["default"] = False
- ruleslist.append(rule)
-
- # We're going to be mutating this a lot, so copy it. We also filter out
- # any experimental default push rules that aren't enabled.
- rules = [
- rule
- for rule in list_with_base_rules(ruleslist)
- if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
- ]
+) -> FilteredPushRules:
+ """Take the DB rows returned from the DB and convert them into a full
+ `FilteredPushRules` object.
+ """
- for i, rule in enumerate(rules):
- rule_id = rule["rule_id"]
+ ruleslist = [
+ PushRule.from_db(
+ rule_id=rawrule["rule_id"],
+ priority_class=rawrule["priority_class"],
+ conditions=rawrule["conditions"],
+ actions=rawrule["actions"],
+ )
+ for rawrule in rawrules
+ ]
- if rule_id not in enabled_map:
- continue
- if rule.get("enabled", True) == bool(enabled_map[rule_id]):
- continue
+ push_rules = PushRules(ruleslist)
- # Rules are cached across users.
- rule = dict(rule)
- rule["enabled"] = bool(enabled_map[rule_id])
- rules[i] = rule
+ filtered_rules = FilteredPushRules(
+ push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled
+ )
- return rules
+ return filtered_rules
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
PusherWorkerStore,
@@ -115,7 +97,6 @@ class PushRulesWorkerStore(
ReceiptsWorkerStore,
EventsWorkerStore,
SQLBaseStore,
- metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
@@ -129,14 +110,14 @@ class PushRulesWorkerStore(
):
super().__init__(database, db_conn, hs)
- if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
- db_conn, "push_rules_stream", "stream_id"
- )
- else:
- self._push_rules_stream_id_gen = SlavedIdTracker(
- db_conn, "push_rules_stream", "stream_id"
- )
+ # In the worker store this is an ID tracker which we overwrite in the non-worker
+ # class below that is used on the main process.
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn,
+ "push_rules_stream",
+ "stream_id",
+ is_writer=hs.config.worker.worker_app is None,
+ )
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
db_conn,
@@ -152,17 +133,26 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
- @abc.abstractmethod
def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
Returns:
int
"""
- raise NotImplementedError()
+ return self._push_rules_stream_id_gen.get_current_token()
+
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
+ if stream_name == PushRulesStream.NAME:
+ self._push_rules_stream_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_push_rules_for_user.invalidate((row.user_id,))
+ self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+ async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -183,7 +173,6 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
- @cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
@@ -216,11 +205,11 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, FilteredPushRules]:
if not user_ids:
return {}
- results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -234,20 +223,19 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
- results.setdefault(row["user_name"], []).append(row)
+ raw_rules.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
- for user_id, rules in results.items():
+ results: Dict[str, FilteredPushRules] = {}
+
+ for user_id, rules in raw_rules.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
)
return results
- @cachedList(
- cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
- )
async def bulk_get_push_rules_enabled(
self, user_ids: Collection[str]
) -> Dict[str, Dict[str, bool]]:
@@ -262,6 +250,7 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
)
for row in rows:
enabled = bool(row["enabled"])
@@ -345,8 +334,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: str,
rule_id: str,
priority_class: int,
- conditions: List[Dict[str, str]],
- actions: List[Union[JsonDict, str]],
+ conditions: Sequence[Mapping[str, str]],
+ actions: Sequence[Union[Mapping[str, Any], str]],
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
@@ -808,7 +797,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
- txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
@@ -817,7 +805,7 @@ class PushRuleStore(PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
+ self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
"""Copy a single push rule from one room to another for a specific user.
@@ -827,21 +815,27 @@ class PushRuleStore(PushRulesWorkerStore):
rule: A push rule.
"""
# Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
+ new_conditions = []
+
# Change room id in each condition
- for condition in rule.get("conditions", []):
+ for condition in rule.conditions:
+ new_condition = condition
if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
+ new_condition = dict(condition)
+ new_condition["pattern"] = new_room_id
+
+ new_conditions.append(new_condition)
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
+ priority_class=rule.priority_class,
+ conditions=new_conditions,
+ actions=rule.actions,
)
async def copy_push_rules_from_room_to_room_for_user(
@@ -859,8 +853,11 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
+ for rule, enabled in user_push_rules.rules():
+ if not enabled:
+ continue
+
+ conditions = rule.conditions
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
|