diff options
-rw-r--r-- | synapse/api/filtering.py | 16 | ||||
-rw-r--r-- | synapse/events/utils.py | 102 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/sync.py | 23 | ||||
-rw-r--r-- | tests/events/test_utils.py | 170 |
4 files changed, 296 insertions, 15 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 3b3ef70750..6f16e4d406 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -71,6 +71,21 @@ class Filtering(object): if key in user_filter_json["room"]: self._check_definition(user_filter_json["room"][key]) + if "event_fields" in user_filter_json: + if type(user_filter_json["event_fields"]) != list: + raise SynapseError(400, "event_fields must be a list of strings") + for field in user_filter_json["event_fields"]: + if not isinstance(field, basestring): + raise SynapseError(400, "Event field must be a string") + # Don't allow '\\' in event field filters. This makes matching + # events a lot easier as we can then use a negative lookbehind + # assertion to split '\.' If we allowed \\ then it would + # incorrectly split '\\.' See synapse.events.utils.serialize_event + if r'\\' in field: + raise SynapseError( + 400, r'The escape character \ cannot itself be escaped' + ) + def _check_definition_room_lists(self, definition): """Check that "rooms" and "not_rooms" are lists of room ids if they are present @@ -152,6 +167,7 @@ class FilterCollection(object): self.include_leave = filter_json.get("room", {}).get( "include_leave", False ) + self.event_fields = filter_json.get("event_fields", []) def __repr__(self): return "<FilterCollection %s>" % (json.dumps(self._filter_json),) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 0e9fd902af..5bbaef8187 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -16,6 +16,17 @@ from synapse.api.constants import EventTypes from . import EventBase +from frozendict import frozendict + +import re + +# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' +# (?<!stuff) matches if the current position in the string is not preceded +# by a match for 'stuff'. +# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as +# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" +SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.') + def prune_event(event): """ Returns a pruned version of the given event, which removes all keys we @@ -97,6 +108,83 @@ def prune_event(event): ) +def _copy_field(src, dst, field): + """Copy the field in 'src' to 'dst'. + + For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] + then dst={"foo":{"bar":5}}. + + Args: + src(dict): The dict to read from. + dst(dict): The dict to modify. + field(list<str>): List of keys to drill down to in 'src'. + """ + if len(field) == 0: # this should be impossible + return + if len(field) == 1: # common case e.g. 'origin_server_ts' + if field[0] in src: + dst[field[0]] = src[field[0]] + return + + # Else is a nested field e.g. 'content.body' + # Pop the last field as that's the key to move across and we need the + # parent dict in order to access the data. Drill down to the right dict. + key_to_move = field.pop(-1) + sub_dict = src + for sub_field in field: # e.g. sub_field => "content" + if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]: + sub_dict = sub_dict[sub_field] + else: + return + + if key_to_move not in sub_dict: + return + + # Insert the key into the output dictionary, creating nested objects + # as required. We couldn't do this any earlier or else we'd need to delete + # the empty objects if the key didn't exist. + sub_out_dict = dst + for sub_field in field: + sub_out_dict = sub_out_dict.setdefault(sub_field, {}) + sub_out_dict[key_to_move] = sub_dict[key_to_move] + + +def only_fields(dictionary, fields): + """Return a new dict with only the fields in 'dictionary' which are present + in 'fields'. + + If there are no event fields specified then all fields are included. + The entries may include '.' charaters to indicate sub-fields. + So ['content.body'] will include the 'body' field of the 'content' object. + A literal '.' character in a field name may be escaped using a '\'. + + Args: + dictionary(dict): The dictionary to read from. + fields(list<str>): A list of fields to copy over. Only shallow refs are + taken. + Returns: + dict: A new dictionary with only the given fields. If fields was empty, + the same dictionary is returned. + """ + if len(fields) == 0: + return dictionary + + # for each field, convert it: + # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]] + split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] + + # for each element of the output array of arrays: + # remove escaping so we can use the right key names. + split_fields[:] = [ + [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields + ] + + output = {} + for field_array in split_fields: + _copy_field(dictionary, output, field_array) + return output + + def format_event_raw(d): return d @@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d): def serialize_event(e, time_now_ms, as_client_event=True, event_format=format_event_for_client_v1, - token_id=None): + token_id=None, only_event_fields=None): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e @@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True, d["unsigned"]["transaction_id"] = txn_id if as_client_event: - return event_format(d) - else: - return d + d = event_format(d) + + if only_event_fields: + if (not isinstance(only_event_fields, list) or + not all(isinstance(f, basestring) for f in only_event_fields)): + raise TypeError("only_event_fields must be a list of strings") + d = only_fields(d, only_event_fields) + + return d diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 6fc63715aa..7199ec883a 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() joined = self.encode_joined( - sync_result.joined, time_now, requester.access_token_id + sync_result.joined, time_now, requester.access_token_id, filter.event_fields ) invited = self.encode_invited( @@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id + sync_result.archived, time_now, requester.access_token_id, filter.event_fields ) response_content = { @@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": formatted} - def encode_joined(self, rooms, time_now, token_id): + def encode_joined(self, rooms, time_now, token_id, event_fields): """ Encode the joined rooms in a sync result @@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: the joined rooms list, in our response format @@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id + room, time_now, token_id, only_fields=event_fields ) return joined @@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet): return invited - def encode_archived(self, rooms, time_now, token_id): + def encode_archived(self, rooms, time_now, token_id, event_fields): """ Encode the archived rooms in a sync result @@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet): calculations token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - + event_fields(list<str>): List of event fields to include. If empty, + all fields will be returned. Returns: dict[str, dict[str, object]]: The invited rooms list, in our response format @@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = self.encode_room( - room, time_now, token_id, joined=False + room, time_now, token_id, joined=False, only_fields=event_fields ) return joined @staticmethod - def encode_room(room, time_now, token_id, joined=True): + def encode_room(room, time_now, token_id, joined=True, only_fields=None): """ Args: room (JoinedSyncResult|ArchivedSyncResult): sync result for a @@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet): of transaction IDs joined (bool): True if the user is joined to this room - will mean we handle ephemeral events - + only_fields(list<str>): Optional. The list of event fields to include. Returns: dict[str, object]: the room, encoded in our response format """ @@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet): return serialize_event( event, time_now, token_id=token_id, event_format=format_event_for_client_v2_without_room_id, + only_event_fields=only_fields, ) state_dict = room.state diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index fb0953c4ec..29f068d1f1 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -17,7 +17,11 @@ from .. import unittest from synapse.events import FrozenEvent -from synapse.events.utils import prune_event +from synapse.events.utils import prune_event, serialize_event + + +def MockEvent(**kwargs): + return FrozenEvent(kwargs) class PruneEventTestCase(unittest.TestCase): @@ -114,3 +118,167 @@ class PruneEventTestCase(unittest.TestCase): 'unsigned': {}, } ) + + +class SerializeEventTestCase(unittest.TestCase): + + def serialize(self, ev, fields): + return serialize_event(ev, 1479807801915, only_event_fields=fields) + + def test_event_fields_works_with_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar" + ), + ["room_id"] + ), + { + "room_id": "!foo:bar", + } + ) + + def test_event_fields_works_with_nested_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "body": "A message", + }, + ), + ["content.body"] + ), + { + "content": { + "body": "A message", + } + } + ) + + def test_event_fields_works_with_dot_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "key.with.dots": {}, + }, + ), + ["content.key\.with\.dots"] + ), + { + "content": { + "key.with.dots": {}, + } + } + ) + + def test_event_fields_works_with_nested_dot_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "not_me": 1, + "nested.dot.key": { + "leaf.key": 42, + "not_me_either": 1, + }, + }, + ), + ["content.nested\.dot\.key.leaf\.key"] + ), + { + "content": { + "nested.dot.key": { + "leaf.key": 42, + }, + } + } + ) + + def test_event_fields_nops_with_unknown_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "foo": "bar", + }, + ), + ["content.foo", "content.notexists"] + ), + { + "content": { + "foo": "bar", + } + } + ) + + def test_event_fields_nops_with_non_dict_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "foo": ["I", "am", "an", "array"], + }, + ), + ["content.foo.am"] + ), + {} + ) + + def test_event_fields_nops_with_array_keys(self): + self.assertEquals( + self.serialize( + MockEvent( + sender="@alice:localhost", + room_id="!foo:bar", + content={ + "foo": ["I", "am", "an", "array"], + }, + ), + ["content.foo.1"] + ), + {} + ) + + def test_event_fields_all_fields_if_empty(self): + self.assertEquals( + self.serialize( + MockEvent( + room_id="!foo:bar", + content={ + "foo": "bar", + }, + ), + [] + ), + { + "room_id": "!foo:bar", + "content": { + "foo": "bar", + }, + "unsigned": {} + } + ) + + def test_event_fields_fail_if_fields_not_str(self): + with self.assertRaises(TypeError): + self.serialize( + MockEvent( + room_id="!foo:bar", + content={ + "foo": "bar", + }, + ), + ["room_id", 4] + ) |