diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 366398e39d..09cb06d614 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import secrets
-from typing import Any, Dict, Generator, List, Tuple
+from typing import Generator, Tuple
from twisted.test.proto_helpers import MemoryReactor
@@ -24,7 +24,7 @@ from synapse.util import Clock
from tests import unittest
-class UpsertManyTests(unittest.HomeserverTestCase):
+class UpdateUpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.storage = hs.get_datastores().main
@@ -46,9 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
- def _dump_to_tuple(
- self, res: List[Dict[str, Any]]
- ) -> Generator[Tuple[int, str, str], None, None]:
+ def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
+ res = self.get_success(
+ self.storage.db_pool.simple_select_list(
+ self.table_name, None, ["id, username, value"]
+ )
+ )
+
for i in res:
yield (i["id"], i["username"], i["value"])
@@ -75,13 +79,8 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
# Check results are what we expect
- res = self.get_success(
- self.storage.db_pool.simple_select_list(
- self.table_name, None, ["id, username, value"]
- )
- )
self.assertEqual(
- set(self._dump_to_tuple(res)),
+ set(self._dump_table_to_tuple()),
{(1, "user1", "hello"), (2, "user2", "there")},
)
@@ -102,12 +101,54 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
# Check results are what we expect
- res = self.get_success(
- self.storage.db_pool.simple_select_list(
- self.table_name, None, ["id, username, value"]
+ self.assertEqual(
+ set(self._dump_table_to_tuple()),
+ {(1, "user1", "hello"), (2, "user2", "bleb")},
+ )
+
+ def test_simple_update_many(self):
+ """
+ simple_update_many performs many updates at once.
+ """
+ # First add some data.
+ self.get_success(
+ self.storage.db_pool.simple_insert_many(
+ table=self.table_name,
+ keys=("id", "username", "value"),
+ values=[(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")],
+ desc="insert",
)
)
+
+ # Check the data made it to the table
self.assertEqual(
- set(self._dump_to_tuple(res)),
- {(1, "user1", "hello"), (2, "user2", "bleb")},
+ set(self._dump_table_to_tuple()),
+ {(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")},
+ )
+
+ # Now use simple_update_many
+ self.get_success(
+ self.storage.db_pool.simple_update_many(
+ table=self.table_name,
+ key_names=("username",),
+ key_values=(
+ ("alice",),
+ ("bob",),
+ ("stranger",),
+ ),
+ value_names=("value",),
+ value_values=(
+ ("aaa!",),
+ ("bbb!",),
+ ("???",),
+ ),
+ desc="update_many1",
+ )
+ )
+
+ # Check the table is how we expect:
+ # charlie has been left alone
+ self.assertEqual(
+ set(self._dump_table_to_tuple()),
+ {(1, "alice", "aaa!"), (2, "bob", "bbb!"), (3, "charlie", "C")},
)
|