summary refs log tree commit diff
diff options
context:
space:
mode:
authorKegan Dougal <kegan@matrix.org>2016-11-21 17:42:16 +0000
committerKegan Dougal <kegan@matrix.org>2016-11-21 17:42:16 +0000
commitf97511a1f3197c6011b5ef7a363885dde9939d6b (patch)
tree1c726c95a488df8dad25532a817b09812504b148
parentAdd filter_event_fields and filter_field to FilterCollection (diff)
downloadsynapse-f97511a1f3197c6011b5ef7a363885dde9939d6b.tar.xz
Move event_fields filtering to serialize_event
Also make it an inclusive not exclusive filter, as the spec demands.
-rw-r--r--synapse/api/filtering.py56
-rw-r--r--synapse/events/utils.py101
-rw-r--r--tests/events/test_utils.py21
3 files changed, 119 insertions, 59 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 27f8b99e3d..4fd0e2d9fa 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -18,7 +18,6 @@ from synapse.types import UserID, RoomID
 from twisted.internet import defer
 
 import ujson as json
-import re
 
 
 class Filtering(object):
@@ -81,7 +80,7 @@ class Filtering(object):
                 # Don't allow '\\' in event field filters. This makes matching
                 # events a lot easier as we can then use a negative lookbehind
                 # assertion to split '\.' If we allowed \\ then it would
-                # incorrectly split '\\.'
+                # incorrectly split '\\.' See synapse.events.utils.serialize_event
                 if r'\\' in field:
                     raise SynapseError(
                         400, r'The escape character \ cannot itself be escaped'
@@ -168,11 +167,6 @@ class FilterCollection(object):
         self.include_leave = filter_json.get("room", {}).get(
             "include_leave", False
         )
-        self._event_fields = filter_json.get("event_fields", [])
-        # Negative lookbehind assertion for '\'
-        # (?<!stuff) matches if the current position in the string is not preceded
-        # by a match for 'stuff'.
-        self._split_field_regex = re.compile(r'(?<!\\)\.')
 
     def __repr__(self):
         return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
@@ -207,54 +201,6 @@ class FilterCollection(object):
     def filter_room_account_data(self, events):
         return self._room_account_data.filter(self._room_filter.filter(events))
 
-    def filter_event_fields(self, event):
-        """Remove fields from an event in accordance with the 'event_fields' of a filter.
-
-        If there are no event fields specified then all fields are included.
-        The entries may include '.' charaters to indicate sub-fields.
-        So ['content.body'] will include the 'body' field of the 'content' object.
-        A literal '.' character in a field name may be escaped using a '\'.
-
-        Args:
-            event(dict): The raw event to filter
-        Returns:
-            dict: The same event with some fields missing, if required.
-        """
-        for field in self._event_fields:
-            self.filter_field(event, field)
-        return event
-
-    def filter_field(self, dictionary, field):
-        """Filter the given field from the given dictionary.
-
-        Args:
-            dictionary(dict): The dictionary to remove the field from.
-            field(str): The key to remove.
-        Returns:
-            dict: The same dictionary with the field removed.
-        """
-        # "content.body.thing\.with\.dots" => ["content", "body", "thing\.with\.dots"]
-        sub_fields = self._split_field_regex.split(field)
-        # remove escaping so we can use the right key names when deleting
-        sub_fields = [f.replace(r'\.', r'.') for f in sub_fields]
-
-        # common case e.g. 'origin_server_ts'
-        if len(sub_fields) == 1:
-            dictionary.pop(sub_fields[0], None)
-        # nested field e.g. 'content.body'
-        elif len(sub_fields) > 1:
-            # Pop the last field as that's the key to delete and we need the
-            # parent dict in order to remove the key. Drill down to the right dict.
-            key_to_delete = sub_fields.pop(-1)
-            sub_dict = dictionary
-            for sub_field in sub_fields:
-                if sub_field in sub_dict and type(sub_dict[sub_field]) == dict:
-                    sub_dict = sub_dict[sub_field]
-                else:
-                    return dictionary
-            sub_dict.pop(key_to_delete, None)
-        return dictionary
-
 
 class Filter(object):
     def __init__(self, filter_json):
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 0e9fd902af..4febd98f43 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -16,6 +16,15 @@
 from synapse.api.constants import EventTypes
 from . import EventBase
 
+import re
+
+# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
+# (?<!stuff) matches if the current position in the string is not preceded
+# by a match for 'stuff'.
+# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
+#       the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
+SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
+
 
 def prune_event(event):
     """ Returns a pruned version of the given event, which removes all keys we
@@ -97,6 +106,87 @@ def prune_event(event):
     )
 
 
+def _copy_field(src, dst, field):
+    """Copy the field in 'src' to 'dst'.
+
+    For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
+    then dst={"foo":{"bar":5}}.
+
+    Args:
+        src(dict): The dict to read from.
+        dst(dict): The dict to modify.
+        field(list<str>): List of keys to drill down to in 'src'.
+    """
+    if len(field) == 0:  # this should be impossible
+        return
+    if len(field) == 1:  # common case e.g. 'origin_server_ts'
+        if field[0] in src:
+            dst[field[0]] = src[field[0]]
+        return
+
+    # Else is a nested field e.g. 'content.body'
+    # Pop the last field as that's the key to move across and we need the
+    # parent dict in order to access the data. Drill down to the right dict.
+    key_to_move = field.pop(-1)
+    sub_dict = src
+    for sub_field in field:  # e.g. sub_field => "content"
+        if sub_field in sub_dict and type(sub_dict[sub_field]) == dict:
+            sub_dict = sub_dict[sub_field]
+        else:
+            return
+
+    if key_to_move not in sub_dict:
+        return
+
+    # Insert the key into the output dictionary, creating nested objects
+    # as required. We couldn't do this any earlier or else we'd need to delete
+    # the empty objects if the key didn't exist.
+    sub_out_dict = dst
+    for sub_field in field:
+        if sub_field not in sub_out_dict:
+            sub_out_dict[sub_field] = {}
+        sub_out_dict = sub_out_dict[sub_field]
+    sub_out_dict[key_to_move] = sub_dict[key_to_move]
+
+
+def only_fields(dictionary, fields):
+    """Return a new dict with only the fields in 'dictionary' which are present
+    in 'fields'.
+
+    If there are no event fields specified then all fields are included.
+    The entries may include '.' charaters to indicate sub-fields.
+    So ['content.body'] will include the 'body' field of the 'content' object.
+    A literal '.' character in a field name may be escaped using a '\'.
+
+    Args:
+        dictionary(dict): The dictionary to read from.
+        fields(list<str>): A list of fields to copy over. Only shallow refs are
+        taken.
+    Returns:
+        dict: A new dictionary with only the given fields. If fields was empty,
+        the same dictionary is returned.
+    """
+    if len(fields) == 0:
+        return dictionary
+
+    # for each field, convert it:
+    # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
+    split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
+
+    # for each element of the output array of arrays:
+    # remove escaping so we can use the right key names. This purposefully avoids
+    # using list comprehensions to avoid needless allocations as this may be called
+    # on a lot of events.
+    for field_array in split_fields:
+        for i, field in enumerate(field_array):
+            field_array[i] = field.replace(r'\.', r'.')
+
+    output = {}
+    for field_array in split_fields:
+        _copy_field(dictionary, output, field_array)
+    return output
+
+
 def format_event_raw(d):
     return d
 
@@ -137,7 +227,7 @@ def format_event_for_client_v2_without_room_id(d):
 
 def serialize_event(e, time_now_ms, as_client_event=True,
                     event_format=format_event_for_client_v1,
-                    token_id=None):
+                    token_id=None, event_fields=None):
     # FIXME(erikj): To handle the case of presence events and the like
     if not isinstance(e, EventBase):
         return e
@@ -164,6 +254,9 @@ def serialize_event(e, time_now_ms, as_client_event=True,
                 d["unsigned"]["transaction_id"] = txn_id
 
     if as_client_event:
-        return event_format(d)
-    else:
-        return d
+        d = event_format(d)
+
+    if isinstance(event_fields, list):
+        d = only_fields(d, event_fields)
+
+    return d
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index fb0953c4ec..b9f55d174d 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -114,3 +114,24 @@ class PruneEventTestCase(unittest.TestCase):
                 'unsigned': {},
             }
         )
+
+
+class SerializeEventTestCase(unittest.TestCase):
+
+    def test_event_fields_works_with_keys(self):
+        pass
+
+    def test_event_fields_works_with_nested_keys(self):
+        pass
+
+    def test_event_fields_works_with_dot_keys(self):
+        pass
+
+    def test_event_fields_works_with_nested_dot_keys(self):
+        pass
+
+    def test_event_fields_nops_with_unknown_keys(self):
+        pass
+
+    def test_event_fields_nops_with_non_dict_keys(self):
+        pass