diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 52eb05bfbf..dd49a14524 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -314,3 +315,90 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)
+
+
+class UpsertManyTests(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.storage = hs.get_datastore()
+
+ self.table_name = "table_" + hs.get_secrets().token_hex(6)
+ self.get_success(
+ self.storage.runInteraction(
+ "create",
+ lambda x, *a: x.execute(*a),
+ "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
+ % (self.table_name,),
+ )
+ )
+ self.get_success(
+ self.storage.runInteraction(
+ "index",
+ lambda x, *a: x.execute(*a),
+ "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
+ % (self.table_name, self.table_name),
+ )
+ )
+
+ def _dump_to_tuple(self, res):
+ for i in res:
+ yield (i["id"], i["username"], i["value"])
+
+ def test_upsert_many(self):
+ """
+ Upsert_many will perform the upsert operation across a batch of data.
+ """
+ # Add some data to an empty table
+ key_names = ["id", "username"]
+ value_names = ["value"]
+ key_values = [[1, "user1"], [2, "user2"]]
+ value_values = [["hello"], ["there"]]
+
+ self.get_success(
+ self.storage.runInteraction(
+ "test",
+ self.storage._simple_upsert_many_txn,
+ self.table_name,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
+ )
+ )
+
+ # Check results are what we expect
+ res = self.get_success(
+ self.storage._simple_select_list(
+ self.table_name, None, ["id, username, value"]
+ )
+ )
+ self.assertEqual(
+ set(self._dump_to_tuple(res)),
+ set([(1, "user1", "hello"), (2, "user2", "there")]),
+ )
+
+ # Update only user2
+ key_values = [[2, "user2"]]
+ value_values = [["bleb"]]
+
+ self.get_success(
+ self.storage.runInteraction(
+ "test",
+ self.storage._simple_upsert_many_txn,
+ self.table_name,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
+ )
+ )
+
+ # Check results are what we expect
+ res = self.get_success(
+ self.storage._simple_select_list(
+ self.table_name, None, ["id, username, value"]
+ )
+ )
+ self.assertEqual(
+ set(self._dump_to_tuple(res)),
+ set([(1, "user1", "hello"), (2, "user2", "bleb")]),
+ )
|