summary refs log tree commit diff
path: root/synapse/events/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/events/utils.py72
1 files changed, 58 insertions, 14 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e6d040176b..e7b7b78b84 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Mapping,
+    Match,
     MutableMapping,
     Optional,
     Union,
@@ -46,12 +47,10 @@ if TYPE_CHECKING:
     from synapse.handlers.relations import BundledAggregations
 
 
-# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
-# (?<!stuff) matches if the current position in the string is not preceded
-# by a match for 'stuff'.
-# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
-#       the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
-SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
+# Split strings on "." but not "\." (or "\\\.").
+SPLIT_FIELD_REGEX = re.compile(r"\\*\.")
+# Find escaped characters, e.g. those with a \ in front of them.
+ESCAPE_SEQUENCE_PATTERN = re.compile(r"\\(.)")
 
 CANONICALJSON_MAX_INT = (2**53) - 1
 CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
@@ -253,6 +252,57 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
     sub_out_dict[key_to_move] = sub_dict[key_to_move]
 
 
+def _escape_slash(m: Match[str]) -> str:
+    """
+    Replacement function; replace a backslash-backslash or backslash-dot with the
+    second character. Leaves any other string alone.
+    """
+    if m.group(1) in ("\\", "."):
+        return m.group(1)
+    return m.group(0)
+
+
+def _split_field(field: str) -> List[str]:
+    """
+    Splits strings on unescaped dots and removes escaping.
+
+    Args:
+        field: A string representing a path to a field.
+
+    Returns:
+        A list of nested fields to traverse.
+    """
+
+    # Convert the field and remove escaping:
+    #
+    # 1. "content.body.thing\.with\.dots"
+    # 2. ["content", "body", "thing\.with\.dots"]
+    # 3. ["content", "body", "thing.with.dots"]
+
+    # Find all dots (and their preceding backslashes). If the dot is unescaped
+    # then emit a new field part.
+    result = []
+    prev_start = 0
+    for match in SPLIT_FIELD_REGEX.finditer(field):
+        # If the match is an *even* number of characters than the dot was escaped.
+        if len(match.group()) % 2 == 0:
+            continue
+
+        # Add a new part (up to the dot, exclusive) after escaping.
+        result.append(
+            ESCAPE_SEQUENCE_PATTERN.sub(
+                _escape_slash, field[prev_start : match.end() - 1]
+            )
+        )
+        prev_start = match.end()
+
+    # Add any part of the field after the last unescaped dot. (Note that if the
+    # character is a dot this correctly adds a blank string.)
+    result.append(re.sub(r"\\(.)", _escape_slash, field[prev_start:]))
+
+    return result
+
+
 def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
     """Return a new dict with only the fields in 'dictionary' which are present
     in 'fields'.
@@ -260,7 +310,7 @@ def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
     If there are no event fields specified then all fields are included.
     The entries may include '.' characters to indicate sub-fields.
     So ['content.body'] will include the 'body' field of the 'content' object.
-    A literal '.' character in a field name may be escaped using a '\'.
+    A literal '.' or '\' character in a field name may be escaped using a '\'.
 
     Args:
         dictionary: The dictionary to read from.
@@ -275,13 +325,7 @@ def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
 
     # for each field, convert it:
     # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
-    split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
-
-    # for each element of the output array of arrays:
-    # remove escaping so we can use the right key names.
-    split_fields[:] = [
-        [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
-    ]
+    split_fields = [_split_field(f) for f in fields]
 
     output: JsonDict = {}
     for field_array in split_fields: