summary refs log tree commit diff
path: root/tests/events/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/events/test_utils.py')
-rw-r--r--tests/events/test_utils.py35
1 files changed, 30 insertions, 5 deletions
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index b1c47efac7..a79256846f 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -12,19 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest as stdlib_unittest
+
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
 from synapse.events.utils import (
     SerializeEventConfig,
     copy_and_fixup_power_levels_contents,
+    maybe_upsert_event_field,
     prune_event,
     serialize_event,
 )
 from synapse.util.frozenutils import freeze
 
-from tests import unittest
-
 
 def MockEvent(**kwargs):
     if "event_id" not in kwargs:
@@ -34,7 +35,31 @@ def MockEvent(**kwargs):
     return make_event_from_dict(kwargs)
 
 
-class PruneEventTestCase(unittest.TestCase):
+class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
+    def test_update_okay(self) -> None:
+        event = make_event_from_dict({"event_id": "$1234"})
+        success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
+        self.assertTrue(success)
+        self.assertEqual(event.unsigned["key"], "value")
+
+    def test_update_not_okay(self) -> None:
+        event = make_event_from_dict({"event_id": "$1234"})
+        LARGE_STRING = "a" * 100_000
+        success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+        self.assertFalse(success)
+        self.assertNotIn("key", event.unsigned)
+
+    def test_update_not_okay_leaves_original_value(self) -> None:
+        event = make_event_from_dict(
+            {"event_id": "$1234", "unsigned": {"key": "value"}}
+        )
+        LARGE_STRING = "a" * 100_000
+        success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+        self.assertFalse(success)
+        self.assertEqual(event.unsigned["key"], "value")
+
+
+class PruneEventTestCase(stdlib_unittest.TestCase):
     def run_test(self, evdict, matchdict, **kwargs):
         """
         Asserts that a new event constructed with `evdict` will look like
@@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
         )
 
 
-class SerializeEventTestCase(unittest.TestCase):
+class SerializeEventTestCase(stdlib_unittest.TestCase):
     def serialize(self, ev, fields):
         return serialize_event(
             ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
@@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
             )
 
 
-class CopyPowerLevelsContentTestCase(unittest.TestCase):
+class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
     def setUp(self) -> None:
         self.test_content = {
             "ban": 50,