diff options
Diffstat (limited to 'synapse/push/push_rule_evaluator.py')
-rw-r--r-- | synapse/push/push_rule_evaluator.py | 70 |
1 files changed, 66 insertions, 4 deletions
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index f617c759e6..54db6b5612 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -129,9 +129,55 @@ class PushRuleEvaluatorForEvent: # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) + # Maps cache keys to final values. + self._condition_cache: Dict[str, bool] = {} + + def check_conditions( + self, conditions: List[dict], uid: str, display_name: Optional[str] + ) -> bool: + """ + Returns true if a user's conditions/user ID/display name match the event. + + Args: + conditions: The user's conditions to match. + uid: The user's MXID. + display_name: The display name. + + Returns: + True if all conditions match the event, False otherwise. + """ + for cond in conditions: + _cache_key = cond.get("_cache_key", None) + if _cache_key: + res = self._condition_cache.get(_cache_key, None) + if res is False: + return False + elif res is True: + continue + + res = self.matches(cond, uid, display_name) + if _cache_key: + self._condition_cache[_cache_key] = bool(res) + + if not res: + return False + + return True + def matches( self, condition: Dict[str, Any], user_id: str, display_name: Optional[str] ) -> bool: + """ + Returns true if a user's condition/user ID/display name match the event. + + Args: + condition: The user's condition to match. + uid: The user's MXID. + display_name: The display name, or None if there is not one. + + Returns: + True if the condition matches the event, False otherwise. + """ if condition["kind"] == "event_match": return self._event_match(condition, user_id) elif condition["kind"] == "contains_display_name": @@ -146,6 +192,16 @@ class PushRuleEvaluatorForEvent: return True def _event_match(self, condition: dict, user_id: str) -> bool: + """ + Check an "event_match" push rule condition. + + Args: + condition: The "event_match" push rule condition to match. + user_id: The user's MXID. + + Returns: + True if the condition matches the event, False otherwise. + """ pattern = condition.get("pattern", None) if not pattern: @@ -167,13 +223,22 @@ class PushRuleEvaluatorForEvent: return _glob_matches(pattern, body, word_boundary=True) else: - haystack = self._get_value(condition["key"]) + haystack = self._value_cache.get(condition["key"], None) if haystack is None: return False return _glob_matches(pattern, haystack) def _contains_display_name(self, display_name: Optional[str]) -> bool: + """ + Check an "event_match" push rule condition. + + Args: + display_name: The display name, or None if there is not one. + + Returns: + True if the display name is found in the event body, False otherwise. + """ if not display_name: return False @@ -191,9 +256,6 @@ class PushRuleEvaluatorForEvent: return bool(r.search(body)) - def _get_value(self, dotted_key: str) -> Optional[str]: - return self._value_cache.get(dotted_key, None) - # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( |