diff options
Diffstat (limited to 'tests/storage/test__base.py')
-rw-r--r-- | tests/storage/test__base.py | 73 |
1 files changed, 57 insertions, 16 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 366398e39d..09cb06d614 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,7 +14,7 @@ # limitations under the License. import secrets -from typing import Any, Dict, Generator, List, Tuple +from typing import Generator, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -24,7 +24,7 @@ from synapse.util import Clock from tests import unittest -class UpsertManyTests(unittest.HomeserverTestCase): +class UpdateUpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.storage = hs.get_datastores().main @@ -46,9 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) - def _dump_to_tuple( - self, res: List[Dict[str, Any]] - ) -> Generator[Tuple[int, str, str], None, None]: + def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]: + res = self.get_success( + self.storage.db_pool.simple_select_list( + self.table_name, None, ["id, username, value"] + ) + ) + for i in res: yield (i["id"], i["username"], i["value"]) @@ -75,13 +79,8 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) # Check results are what we expect - res = self.get_success( - self.storage.db_pool.simple_select_list( - self.table_name, None, ["id, username, value"] - ) - ) self.assertEqual( - set(self._dump_to_tuple(res)), + set(self._dump_table_to_tuple()), {(1, "user1", "hello"), (2, "user2", "there")}, ) @@ -102,12 +101,54 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) # Check results are what we expect - res = self.get_success( - self.storage.db_pool.simple_select_list( - self.table_name, None, ["id, username, value"] + self.assertEqual( + set(self._dump_table_to_tuple()), + {(1, "user1", "hello"), (2, "user2", "bleb")}, + ) + + def test_simple_update_many(self): + """ + simple_update_many performs many updates at once. + """ + # First add some data. + self.get_success( + self.storage.db_pool.simple_insert_many( + table=self.table_name, + keys=("id", "username", "value"), + values=[(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")], + desc="insert", ) ) + + # Check the data made it to the table self.assertEqual( - set(self._dump_to_tuple(res)), - {(1, "user1", "hello"), (2, "user2", "bleb")}, + set(self._dump_table_to_tuple()), + {(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")}, + ) + + # Now use simple_update_many + self.get_success( + self.storage.db_pool.simple_update_many( + table=self.table_name, + key_names=("username",), + key_values=( + ("alice",), + ("bob",), + ("stranger",), + ), + value_names=("value",), + value_values=( + ("aaa!",), + ("bbb!",), + ("???",), + ), + desc="update_many1", + ) + ) + + # Check the table is how we expect: + # charlie has been left alone + self.assertEqual( + set(self._dump_table_to_tuple()), + {(1, "alice", "aaa!"), (2, "bob", "bbb!"), (3, "charlie", "C")}, ) |