summary refs log tree commit diff
path: root/tests/storage/test__base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test__base.py')
-rw-r--r--tests/storage/test__base.py73
1 files changed, 57 insertions, 16 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 366398e39d..09cb06d614 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import secrets
-from typing import Any, Dict, Generator, List, Tuple
+from typing import Generator, Tuple
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -24,7 +24,7 @@ from synapse.util import Clock
 from tests import unittest
 
 
-class UpsertManyTests(unittest.HomeserverTestCase):
+class UpdateUpsertManyTests(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.storage = hs.get_datastores().main
 
@@ -46,9 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
 
-    def _dump_to_tuple(
-        self, res: List[Dict[str, Any]]
-    ) -> Generator[Tuple[int, str, str], None, None]:
+    def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
+        res = self.get_success(
+            self.storage.db_pool.simple_select_list(
+                self.table_name, None, ["id, username, value"]
+            )
+        )
+
         for i in res:
             yield (i["id"], i["username"], i["value"])
 
@@ -75,13 +79,8 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         )
 
         # Check results are what we expect
-        res = self.get_success(
-            self.storage.db_pool.simple_select_list(
-                self.table_name, None, ["id, username, value"]
-            )
-        )
         self.assertEqual(
-            set(self._dump_to_tuple(res)),
+            set(self._dump_table_to_tuple()),
             {(1, "user1", "hello"), (2, "user2", "there")},
         )
 
@@ -102,12 +101,54 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         )
 
         # Check results are what we expect
-        res = self.get_success(
-            self.storage.db_pool.simple_select_list(
-                self.table_name, None, ["id, username, value"]
+        self.assertEqual(
+            set(self._dump_table_to_tuple()),
+            {(1, "user1", "hello"), (2, "user2", "bleb")},
+        )
+
+    def test_simple_update_many(self):
+        """
+        simple_update_many performs many updates at once.
+        """
+        # First add some data.
+        self.get_success(
+            self.storage.db_pool.simple_insert_many(
+                table=self.table_name,
+                keys=("id", "username", "value"),
+                values=[(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")],
+                desc="insert",
             )
         )
+
+        # Check the data made it to the table
         self.assertEqual(
-            set(self._dump_to_tuple(res)),
-            {(1, "user1", "hello"), (2, "user2", "bleb")},
+            set(self._dump_table_to_tuple()),
+            {(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")},
+        )
+
+        # Now use simple_update_many
+        self.get_success(
+            self.storage.db_pool.simple_update_many(
+                table=self.table_name,
+                key_names=("username",),
+                key_values=(
+                    ("alice",),
+                    ("bob",),
+                    ("stranger",),
+                ),
+                value_names=("value",),
+                value_values=(
+                    ("aaa!",),
+                    ("bbb!",),
+                    ("???",),
+                ),
+                desc="update_many1",
+            )
+        )
+
+        # Check the table is how we expect:
+        # charlie has been left alone
+        self.assertEqual(
+            set(self._dump_table_to_tuple()),
+            {(1, "alice", "aaa!"), (2, "bob", "bbb!"), (3, "charlie", "C")},
         )