summary refs log tree commit diff
path: root/synapse/push/push_rule_evaluator.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push/push_rule_evaluator.py')
-rw-r--r--synapse/push/push_rule_evaluator.py109
1 files changed, 47 insertions, 62 deletions
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 0816b632b4..78d4b564d4 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -127,7 +127,7 @@ class PushRuleEvaluator:
         room_members = yield self.store.get_users_in_room(room_id)
         room_member_count = len(room_members)
 
-        evaluator = PushRuleEvaluatorForEvent.create(ev, room_member_count)
+        evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
 
         for r in self.rules:
             if self.enabled_map.get(r['rule_id'], None) is False:
@@ -180,33 +180,13 @@ class PushRuleEvaluator:
 
 
 class PushRuleEvaluatorForEvent(object):
-    WORD_BOUNDARY = re.compile(r'\b')
-
-    def __init__(self, event, body_parts, room_member_count):
+    def __init__(self, event, room_member_count):
         self._event = event
-
-        # This is a list of words of the content.body (if event has one). Each
-        # word has been converted to lower case.
-        self._body_parts = body_parts
-
         self._room_member_count = room_member_count
 
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
-    @staticmethod
-    def create(event, room_member_count):
-        body = event.get("content", {}).get("body", None)
-        if body:
-            body_parts = PushRuleEvaluatorForEvent.WORD_BOUNDARY.split(body)
-            body_parts[:] = [
-                part.lower() for part in body_parts
-            ]
-        else:
-            body_parts = []
-
-        return PushRuleEvaluatorForEvent(event, body_parts, room_member_count)
-
     def matches(self, condition, user_id, display_name, profile_tag):
         if condition['kind'] == 'event_match':
             return self._event_match(condition, user_id)
@@ -239,67 +219,72 @@ class PushRuleEvaluatorForEvent(object):
 
         # XXX: optimisation: cache our pattern regexps
         if condition['key'] == 'content.body':
-            matcher = _glob_to_matcher(pattern)
+            body = self._event["content"].get("body", None)
+            if not body:
+                return False
 
-            for part in self._body_parts:
-                if matcher(part):
-                    return True
-            return False
+            return _glob_matches(pattern, body, word_boundary=True)
         else:
             haystack = self._get_value(condition['key'])
             if haystack is None:
                 return False
 
-            matcher = _glob_to_matcher(pattern)
-
-            return matcher(haystack.lower())
+            return _glob_matches(pattern, haystack)
 
     def _contains_display_name(self, display_name):
         if not display_name:
             return False
 
-        lower_display_name = display_name.lower()
-        for part in self._body_parts:
-            if part == lower_display_name:
-                return True
+        body = self._event["content"].get("body", None)
+        if not body:
+            return False
 
-        return False
+        return _glob_matches(display_name, body, word_boundary=True)
 
     def _get_value(self, dotted_key):
         return self._value_cache.get(dotted_key, None)
 
 
-def _glob_to_matcher(glob):
-    """Takes a glob and returns a `func(string) -> bool`, which returns if the
-    string matches the glob. Assumes given string is lower case.
-
-    The matcher returned is either a simple string comparison for globs without
-    wildcards, or a regex matcher for globs with wildcards.
-    """
-    glob = glob.lower()
-
-    if not IS_GLOB.search(glob):
-        return lambda value: value == glob
+def _glob_matches(glob, value, word_boundary=False):
+    """Tests if value matches glob.
 
-    r = re.escape(glob)
+    Args:
+        glob (string)
+        value (string): String to test against glob.
+        word_boundary (bool): Whether to match against word boundaries or entire
+            string. Defaults to False.
 
-    r = r.replace(r'\*', '.*?')
-    r = r.replace(r'\?', '.')
+    Returns:
+        bool
+    """
+    if IS_GLOB.search(glob):
+        r = re.escape(glob)
+
+        r = r.replace(r'\*', '.*?')
+        r = r.replace(r'\?', '.')
+
+        # handle [abc], [a-z] and [!a-z] style ranges.
+        r = GLOB_REGEX.sub(
+            lambda x: (
+                '[%s%s]' % (
+                    x.group(1) and '^' or '',
+                    x.group(2).replace(r'\\\-', '-')
+                )
+            ),
+            r,
+        )
+        r = r + "$"
+        r = re.compile(r, flags=re.IGNORECASE)
 
-    # handle [abc], [a-z] and [!a-z] style ranges.
-    r = GLOB_REGEX.sub(
-        lambda x: (
-            '[%s%s]' % (
-                x.group(1) and '^' or '',
-                x.group(2).replace(r'\\\-', '-')
-            )
-        ),
-        r,
-    )
+        return r.match(value)
+    elif word_boundary:
+        r = re.escape(glob)
+        r = "\b%s\b" % (r,)
+        r = re.compile(r, flags=re.IGNORECASE)
 
-    r = r + "$"
-    r = re.compile(r)
-    return lambda value: r.match(value)
+        return r.search(value)
+    else:
+        return value.lower() == glob.lower()
 
 
 def _flatten_dict(d, prefix=[], result={}):