diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
index fe6eeeaf10..acc59e4689 100644
--- a/tests/storage/test__init__.py
+++ b/tests/storage/test__init__.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
-import tests.unittest
+
import tests.utils
@@ -33,39 +33,34 @@ class InitTestCase(tests.unittest.TestCase):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ @defer.inlineCallbacks
def test_count_monthly_users(self):
- count = self.store.count_monthly_users()
+ count = yield self.store.count_monthly_users()
self.assertEqual(0, count)
- self._insert_user_ips("@user:server1")
- self._insert_user_ips("@user:server2")
+ yield self._insert_user_ips("@user:server1")
+ yield self._insert_user_ips("@user:server2")
- count = self.store.count_monthly_users()
+ count = yield self.store.count_monthly_users()
self.assertEqual(2, count)
+ @defer.inlineCallbacks
def _insert_user_ips(self, user):
"""
Helper function to populate user_ips without using batch insertion infra
args:
user (str): specify username i.e. @user:server.com
"""
- try:
- txn = self.store.db_conn.cursor()
- self.store.database_engine.lock_table(txn, "user_ips")
- self.store._simple_upsert_txn(
- txn,
- table="user_ips",
- keyvalues={
- "user_id": user,
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": "device_id",
- },
- values={
- "last_seen": self.clock.time_msec(),
- },
- lock=False,
- )
- finally:
- txn.close()
+ yield self.store._simple_upsert(
+ table="user_ips",
+ keyvalues={
+ "user_id": user,
+ "access_token": "access_token",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "device_id": "device_id",
+ },
+ values={
+ "last_seen": self.clock.time_msec(),
+ }
+ )
|