summary refs log tree commit diff
path: root/tests/storage/test__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test__init__.py')
-rw-r--r--tests/storage/test__init__.py45
1 files changed, 20 insertions, 25 deletions
diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
index fe6eeeaf10..acc59e4689 100644
--- a/tests/storage/test__init__.py
+++ b/tests/storage/test__init__.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-import tests.unittest
+
 import tests.utils
 
 
@@ -33,39 +33,34 @@ class InitTestCase(tests.unittest.TestCase):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
 
+    @defer.inlineCallbacks
     def test_count_monthly_users(self):
-        count = self.store.count_monthly_users()
+        count = yield self.store.count_monthly_users()
         self.assertEqual(0, count)
 
-        self._insert_user_ips("@user:server1")
-        self._insert_user_ips("@user:server2")
+        yield self._insert_user_ips("@user:server1")
+        yield self._insert_user_ips("@user:server2")
 
-        count = self.store.count_monthly_users()
+        count = yield self.store.count_monthly_users()
         self.assertEqual(2, count)
 
+    @defer.inlineCallbacks
     def _insert_user_ips(self, user):
         """
         Helper function to populate user_ips without using batch insertion infra
         args:
             user (str):  specify username i.e. @user:server.com
         """
-        try:
-            txn = self.store.db_conn.cursor()
-            self.store.database_engine.lock_table(txn, "user_ips")
-            self.store._simple_upsert_txn(
-                txn,
-                table="user_ips",
-                keyvalues={
-                    "user_id": user,
-                    "access_token": "access_token",
-                    "ip": "ip",
-                    "user_agent": "user_agent",
-                    "device_id": "device_id",
-                },
-                values={
-                    "last_seen": self.clock.time_msec(),
-                },
-                lock=False,
-            )
-        finally:
-            txn.close()
+        yield self.store._simple_upsert(
+            table="user_ips",
+            keyvalues={
+                "user_id": user,
+                "access_token": "access_token",
+                "ip": "ip",
+                "user_agent": "user_agent",
+                "device_id": "device_id",
+            },
+            values={
+                "last_seen": self.clock.time_msec(),
+            }
+        )