summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/events/utils.py20
-rw-r--r--tests/events/test_utils.py12
2 files changed, 11 insertions, 21 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f4b21ca517..5bbaef8187 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -145,9 +145,7 @@ def _copy_field(src, dst, field):
     # the empty objects if the key didn't exist.
     sub_out_dict = dst
     for sub_field in field:
-        if sub_field not in sub_out_dict:
-            sub_out_dict[sub_field] = {}
-        sub_out_dict = sub_out_dict[sub_field]
+        sub_out_dict = sub_out_dict.setdefault(sub_field, {})
     sub_out_dict[key_to_move] = sub_dict[key_to_move]
 
 
@@ -176,12 +174,10 @@ def only_fields(dictionary, fields):
     split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
 
     # for each element of the output array of arrays:
-    # remove escaping so we can use the right key names. This purposefully avoids
-    # using list comprehensions to avoid needless allocations as this may be called
-    # on a lot of events.
-    for field_array in split_fields:
-        for i, field in enumerate(field_array):
-            field_array[i] = field.replace(r'\.', r'.')
+    # remove escaping so we can use the right key names.
+    split_fields[:] = [
+        [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
+    ]
 
     output = {}
     for field_array in split_fields:
@@ -258,8 +254,10 @@ def serialize_event(e, time_now_ms, as_client_event=True,
     if as_client_event:
         d = event_format(d)
 
-    if (only_event_fields and isinstance(only_event_fields, list) and
-            all(isinstance(f, basestring) for f in only_event_fields)):
+    if only_event_fields:
+        if (not isinstance(only_event_fields, list) or
+                not all(isinstance(f, basestring) for f in only_event_fields)):
+            raise TypeError("only_event_fields must be a list of strings")
         d = only_fields(d, only_event_fields)
 
     return d
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 5b3326ce8d..29f068d1f1 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -272,7 +272,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_fail_if_fields_not_str(self):
-        self.assertEquals(
+        with self.assertRaises(TypeError):
             self.serialize(
                 MockEvent(
                     room_id="!foo:bar",
@@ -281,12 +281,4 @@ class SerializeEventTestCase(unittest.TestCase):
                     },
                 ),
                 ["room_id", 4]
-            ),
-            {
-                "room_id": "!foo:bar",
-                "content": {
-                    "foo": "bar",
-                },
-                "unsigned": {}
-            }
-        )
+            )