summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7271.bugfix1
-rw-r--r--synapse/push/push_rule_evaluator.py69
-rw-r--r--tests/push/test_push_rule_evaluator.py65
-rw-r--r--tox.ini1
4 files changed, 106 insertions, 30 deletions
diff --git a/changelog.d/7271.bugfix b/changelog.d/7271.bugfix
new file mode 100644
index 0000000000..e8315e4ce4
--- /dev/null
+++ b/changelog.d/7271.bugfix
@@ -0,0 +1 @@
+Do not treat display names as globs in push rules.
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index b1587183a8..4cd702b5fa 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,9 +16,11 @@
 
 import logging
 import re
+from typing import Pattern
 
 from six import string_types
 
+from synapse.events import EventBase
 from synapse.types import UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
@@ -56,18 +58,18 @@ def _test_ineq_condition(condition, number):
     rhs = m.group(2)
     if not rhs.isdigit():
         return False
-    rhs = int(rhs)
+    rhs_int = int(rhs)
 
     if ineq == "" or ineq == "==":
-        return number == rhs
+        return number == rhs_int
     elif ineq == "<":
-        return number < rhs
+        return number < rhs_int
     elif ineq == ">":
-        return number > rhs
+        return number > rhs_int
     elif ineq == ">=":
-        return number >= rhs
+        return number >= rhs_int
     elif ineq == "<=":
-        return number <= rhs
+        return number <= rhs_int
     else:
         return False
 
@@ -83,7 +85,13 @@ def tweaks_for_actions(actions):
 
 
 class PushRuleEvaluatorForEvent(object):
-    def __init__(self, event, room_member_count, sender_power_level, power_levels):
+    def __init__(
+        self,
+        event: EventBase,
+        room_member_count: int,
+        sender_power_level: int,
+        power_levels: dict,
+    ):
         self._event = event
         self._room_member_count = room_member_count
         self._sender_power_level = sender_power_level
@@ -92,7 +100,7 @@ class PushRuleEvaluatorForEvent(object):
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
-    def matches(self, condition, user_id, display_name):
+    def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
         if condition["kind"] == "event_match":
             return self._event_match(condition, user_id)
         elif condition["kind"] == "contains_display_name":
@@ -106,7 +114,7 @@ class PushRuleEvaluatorForEvent(object):
         else:
             return True
 
-    def _event_match(self, condition, user_id):
+    def _event_match(self, condition: dict, user_id: str) -> bool:
         pattern = condition.get("pattern", None)
 
         if not pattern:
@@ -134,7 +142,7 @@ class PushRuleEvaluatorForEvent(object):
 
             return _glob_matches(pattern, haystack)
 
-    def _contains_display_name(self, display_name):
+    def _contains_display_name(self, display_name: str) -> bool:
         if not display_name:
             return False
 
@@ -142,51 +150,52 @@ class PushRuleEvaluatorForEvent(object):
         if not body:
             return False
 
-        return _glob_matches(display_name, body, word_boundary=True)
+        # Similar to _glob_matches, but do not treat display_name as a glob.
+        r = regex_cache.get((display_name, False, True), None)
+        if not r:
+            r = re.escape(display_name)
+            r = _re_word_boundary(r)
+            r = re.compile(r, flags=re.IGNORECASE)
+            regex_cache[(display_name, False, True)] = r
+
+        return r.search(body)
 
-    def _get_value(self, dotted_key):
+    def _get_value(self, dotted_key: str) -> str:
         return self._value_cache.get(dotted_key, None)
 
 
-# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
 regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
 register_cache("cache", "regex_push_cache", regex_cache)
 
 
-def _glob_matches(glob, value, word_boundary=False):
+def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
     """Tests if value matches glob.
 
     Args:
-        glob (string)
-        value (string): String to test against glob.
-        word_boundary (bool): Whether to match against word boundaries or entire
+        glob
+        value: String to test against glob.
+        word_boundary: Whether to match against word boundaries or entire
             string. Defaults to False.
-
-    Returns:
-        bool
     """
 
     try:
-        r = regex_cache.get((glob, word_boundary), None)
+        r = regex_cache.get((glob, True, word_boundary), None)
         if not r:
             r = _glob_to_re(glob, word_boundary)
-            regex_cache[(glob, word_boundary)] = r
+            regex_cache[(glob, True, word_boundary)] = r
         return r.search(value)
     except re.error:
         logger.warning("Failed to parse glob to regex: %r", glob)
         return False
 
 
-def _glob_to_re(glob, word_boundary):
+def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
     """Generates regex for a given glob.
 
     Args:
-        glob (string)
-        word_boundary (bool): Whether to match against word boundaries or entire
-            string. Defaults to False.
-
-    Returns:
-        regex object
+        glob
+        word_boundary: Whether to match against word boundaries or entire string.
     """
     if IS_GLOB.search(glob):
         r = re.escape(glob)
@@ -219,7 +228,7 @@ def _glob_to_re(glob, word_boundary):
         return re.compile(r, flags=re.IGNORECASE)
 
 
-def _re_word_boundary(r):
+def _re_word_boundary(r: str) -> str:
     """
     Adds word boundary characters to the start and end of an
     expression to require that the match occur as a whole word,
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
new file mode 100644
index 0000000000..9ae6a87d7b
--- /dev/null
+++ b/tests/push/test_push_rule_evaluator.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+
+from tests import unittest
+
+
+class PushRuleEvaluatorTestCase(unittest.TestCase):
+    def setUp(self):
+        event = FrozenEvent(
+            {
+                "event_id": "$event_id",
+                "type": "m.room.history_visibility",
+                "sender": "@user:test",
+                "state_key": "",
+                "room_id": "@room:test",
+                "content": {"body": "foo bar baz"},
+            },
+            RoomVersions.V1,
+        )
+        room_member_count = 0
+        sender_power_level = 0
+        power_levels = {}
+        self.evaluator = PushRuleEvaluatorForEvent(
+            event, room_member_count, sender_power_level, power_levels
+        )
+
+    def test_display_name(self):
+        """Check for a matching display name in the body of the event."""
+        condition = {
+            "kind": "contains_display_name",
+        }
+
+        # Blank names are skipped.
+        self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+
+        # Check a display name that doesn't match.
+        self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+
+        # Check a display name which matches.
+        self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+
+        # A display name that matches, but not a full word does not result in a match.
+        self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+
+        # A display name should not be interpreted as a regular expression.
+        self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+
+        # A display name with spaces should work fine.
+        self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
diff --git a/tox.ini b/tox.ini
index 763c8463d9..42b2d74891 100644
--- a/tox.ini
+++ b/tox.ini
@@ -196,6 +196,7 @@ commands = mypy \
             synapse/metrics \
             synapse/module_api \
             synapse/push/pusherpool.py \
+            synapse/push/push_rule_evaluator.py \
             synapse/replication \
             synapse/rest \
             synapse/spam_checker_api \