diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f4b21ca517..5bbaef8187 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -145,9 +145,7 @@ def _copy_field(src, dst, field):
# the empty objects if the key didn't exist.
sub_out_dict = dst
for sub_field in field:
- if sub_field not in sub_out_dict:
- sub_out_dict[sub_field] = {}
- sub_out_dict = sub_out_dict[sub_field]
+ sub_out_dict = sub_out_dict.setdefault(sub_field, {})
sub_out_dict[key_to_move] = sub_dict[key_to_move]
@@ -176,12 +174,10 @@ def only_fields(dictionary, fields):
split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
# for each element of the output array of arrays:
- # remove escaping so we can use the right key names. This purposefully avoids
- # using list comprehensions to avoid needless allocations as this may be called
- # on a lot of events.
- for field_array in split_fields:
- for i, field in enumerate(field_array):
- field_array[i] = field.replace(r'\.', r'.')
+ # remove escaping so we can use the right key names.
+ split_fields[:] = [
+ [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
+ ]
output = {}
for field_array in split_fields:
@@ -258,8 +254,10 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if as_client_event:
d = event_format(d)
- if (only_event_fields and isinstance(only_event_fields, list) and
- all(isinstance(f, basestring) for f in only_event_fields)):
+ if only_event_fields:
+ if (not isinstance(only_event_fields, list) or
+ not all(isinstance(f, basestring) for f in only_event_fields)):
+ raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
return d
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 5b3326ce8d..29f068d1f1 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -272,7 +272,7 @@ class SerializeEventTestCase(unittest.TestCase):
)
def test_event_fields_fail_if_fields_not_str(self):
- self.assertEquals(
+ with self.assertRaises(TypeError):
self.serialize(
MockEvent(
room_id="!foo:bar",
@@ -281,12 +281,4 @@ class SerializeEventTestCase(unittest.TestCase):
},
),
["room_id", 4]
- ),
- {
- "room_id": "!foo:bar",
- "content": {
- "foo": "bar",
- },
- "unsigned": {}
- }
- )
+ )
|