diff options
-rw-r--r-- | synapse/events/utils.py | 20 | ||||
-rw-r--r-- | tests/events/test_utils.py | 12 |
2 files changed, 11 insertions, 21 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f4b21ca517..5bbaef8187 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -145,9 +145,7 @@ def _copy_field(src, dst, field): # the empty objects if the key didn't exist. sub_out_dict = dst for sub_field in field: - if sub_field not in sub_out_dict: - sub_out_dict[sub_field] = {} - sub_out_dict = sub_out_dict[sub_field] + sub_out_dict = sub_out_dict.setdefault(sub_field, {}) sub_out_dict[key_to_move] = sub_dict[key_to_move] @@ -176,12 +174,10 @@ def only_fields(dictionary, fields): split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] # for each element of the output array of arrays: - # remove escaping so we can use the right key names. This purposefully avoids - # using list comprehensions to avoid needless allocations as this may be called - # on a lot of events. - for field_array in split_fields: - for i, field in enumerate(field_array): - field_array[i] = field.replace(r'\.', r'.') + # remove escaping so we can use the right key names. + split_fields[:] = [ + [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields + ] output = {} for field_array in split_fields: @@ -258,8 +254,10 @@ def serialize_event(e, time_now_ms, as_client_event=True, if as_client_event: d = event_format(d) - if (only_event_fields and isinstance(only_event_fields, list) and - all(isinstance(f, basestring) for f in only_event_fields)): + if only_event_fields: + if (not isinstance(only_event_fields, list) or + not all(isinstance(f, basestring) for f in only_event_fields)): + raise TypeError("only_event_fields must be a list of strings") d = only_fields(d, only_event_fields) return d diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 5b3326ce8d..29f068d1f1 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -272,7 +272,7 @@ class SerializeEventTestCase(unittest.TestCase): ) def test_event_fields_fail_if_fields_not_str(self): - self.assertEquals( + with self.assertRaises(TypeError): self.serialize( MockEvent( room_id="!foo:bar", @@ -281,12 +281,4 @@ class SerializeEventTestCase(unittest.TestCase): }, ), ["room_id", 4] - ), - { - "room_id": "!foo:bar", - "content": { - "foo": "bar", - }, - "unsigned": {} - } - ) + ) |